#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Batch text-to-video generation benchmark.

Reads prompts from a JSON file, renders one video per prompt with either a
ModelScope text-to-video pipeline or a diffusers pipeline, and writes per-case
timings plus aggregate latency statistics to a results JSON file.
"""
import argparse
import json
import os
import re
import time
from datetime import datetime
from pathlib import Path

import patch
from modelscope.pipelines import pipeline
from modelscope.outputs import OutputKeys
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video


def safe_stem(text: str, maxlen: int = 60) -> str:
    """Turn a prompt into a filesystem-safe filename fragment.

    Runs of whitespace become underscores, all other characters outside
    ``[A-Za-z0-9_-]`` are removed, and the result is truncated to *maxlen*.
    Falls back to ``"image"`` when nothing survives the filtering.
    """
    text = re.sub(r"\s+", "_", text.strip())
    text = re.sub(r"[^A-Za-z0-9_\-]+", "", text)
    return (text[:maxlen] or "image").strip("_")


def load_prompts(json_path: Path):
    """Load test prompts from *json_path*.

    Two top-level JSON layouts are accepted:
      1) ``["prompt 1", "prompt 2", ...]`` — a list of strings;
      2) ``[{"prompt": "...", ...}, ...]`` — a list of objects, each of which
         must contain a ``"prompt"`` key.

    Returns a list of dicts, each with at least a ``"prompt"`` key.
    Raises ``ValueError`` for any other structure.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    prompts = []
    if isinstance(data, list):
        if all(isinstance(x, str) for x in data):
            for s in data:
                prompts.append({"prompt": s})
        elif all(isinstance(x, dict) for x in data):
            for obj in data:
                if "prompt" not in obj:
                    raise ValueError("每个对象都需要包含 'prompt' 字段")
                prompts.append(obj)
        else:
            raise ValueError("JSON 列表元素需全为字符串或全为对象。")
    else:
        raise ValueError("JSON 顶层必须是列表。")
    return prompts


def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16,
                   model_type: str = "text-to-video-synthesis"):
    """Construct the inference pipeline for *model_type*.

    - ``"text-to-video-synthesis"``: ModelScope pipeline (ignores *dtype*).
    - ``"text-to-video-ms"``: diffusers ``DiffusionPipeline`` with CPU offload
      and VAE slicing enabled to reduce GPU memory usage.

    Raises ``ValueError`` for an unknown *model_type*.
    """
    if model_type == "text-to-video-synthesis":
        pipe = pipeline('text-to-video-synthesis', model_path, device=device)
    elif model_type == "text-to-video-ms":
        pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)
        pipe.enable_model_cpu_offload()  # saves GPU memory
        pipe.enable_vae_slicing()
    else:
        raise ValueError(f"不支持的模型类型: {model_type}")
    return pipe


def generate_one(pipe, cfg: dict, out_dir: Path, index: int,
                 model_type: str = "text-to-video-synthesis"):
    """Generate one video from *cfg* and return ``(path, seconds, detail)``.

    Supported *cfg* fields:
      - ``prompt`` (required): the text prompt.

    The output filename encodes the case index, a sanitized prompt stem, and
    a timestamp, so repeated runs never collide.
    """
    prompt = cfg["prompt"]
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    stem = safe_stem(prompt)
    filename = f"{index:03d}_{stem}_{stamp}.mp4"
    out_path = out_dir / filename

    start = time.time()
    if model_type == "text-to-video-synthesis":
        # ModelScope writes the video to out_path itself; the returned path
        # duplicates it, so the result is not bound to a variable.
        pipe({"text": prompt}, output_video=str(out_path))[OutputKeys.OUTPUT_VIDEO]
    elif model_type == "text-to-video-ms":
        frames = pipe(prompt, num_frames=16).frames[0]
        export_to_video(frames, str(out_path))
    else:
        raise ValueError(f"不支持的模型类型: {model_type}")
    elapsed = time.time() - start

    detail = {
        "index": index,
        "filename": filename,
        "elapsed_seconds": round(elapsed, 6),
        "prompt": prompt,
    }
    return out_path, elapsed, detail


def main():
    """CLI entry point: parse args, run all prompts, write results JSON."""
    parser = argparse.ArgumentParser(
        description="Stable Diffusion 基准与批量生成脚本(JSON 结果)"
    )
    parser.add_argument("--model", required=True,
                        help="模型路径或模型名(本地目录或 HF 仓库名)")
    parser.add_argument("--json", required=True, help="测试文本 JSON 文件路径")
    parser.add_argument("--results", required=True,
                        help="结果 JSON 文件输出路径(*.json)")
    parser.add_argument("--outdir", required=True, help="图片输出目录")
    parser.add_argument("--device", default="cuda", help="推理设备")
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"],
                        help="推理精度")
    parser.add_argument("--model_type", default="text-to-video-synthesis",
                        choices=["text-to-video-synthesis", "text-to-video-ms"],
                        help="模型类型")
    # parse_known_args so extra launcher-injected flags are tolerated.
    args, _ = parser.parse_known_args()

    model_path = args.model
    json_path = Path(args.json)
    results_path = Path(args.results)
    out_dir = Path(args.outdir)
    out_dir.mkdir(parents=True, exist_ok=True)
    results_path.parent.mkdir(parents=True, exist_ok=True)

    dtype = torch.float16 if args.dtype == "fp16" else torch.float32
    prompts = load_prompts(json_path)
    if not prompts:
        raise ValueError("测试列表为空。")

    pipe = build_pipeline(model_path=model_path, device=args.device,
                          dtype=dtype, model_type=args.model_type)

    records = []
    total_start = time.time()
    for i, cfg in enumerate(prompts, 1):
        out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i,
                                                 model_type=args.model_type)
        print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
        records.append(detail)
    total_elapsed = round(time.time() - total_start, 6)
    avg_latency = total_elapsed / len(records) if records else 0

    # Results JSON structure.
    result_obj = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "model": model_path,
        "device": str(getattr(pipe, "device", "unknown")),
        "dtype": "fp16" if dtype == torch.float16 else "fp32",
        "count": len(records),
        "total_elapsed_seconds": total_elapsed,
        "avg_latency": avg_latency,
        "cases": records,
    }
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(result_obj, f, ensure_ascii=False, indent=2)

    # "vidoes" typo in the original summary line fixed to "videos".
    print(f"\nAll done. videos: {len(records)}, total_elapsed: {total_elapsed:.3f}s, avg_latency: {avg_latency:.3f}")
    print(f"Results JSON: {results_path}")
    print(f"Output dir : {out_dir.resolve()}")


if __name__ == "__main__":
    main()