Files
2025-10-16 10:45:15 +08:00

176 lines
5.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import sys
import re
import time
from datetime import datetime
from pathlib import Path
import torch
from functools import wraps
sys.path.append('./Wan2.1/')
from wan_pipeline import build_pipeline as wan_build_pipeline, generate_one as wan_generate_one
_orig_load = torch.load


@wraps(_orig_load)
def _load_patch(*args, **kwargs):
    """Drop-in replacement for ``torch.load`` defaulting ``weights_only`` to False.

    Callers that pass ``weights_only`` explicitly keep their own value; every
    other argument is forwarded unchanged to the original ``torch.load``.

    NOTE(review): ``weights_only=False`` lets ``torch.load`` unpickle arbitrary
    objects, i.e. run arbitrary code from the checkpoint file. Only load
    trusted checkpoints while this patch is active.
    NOTE(review): the patch replaces the ``torch.load`` attribute, so it only
    affects code that looks ``torch.load`` up at call time. ``wan_pipeline`` is
    imported before this runs; if it bound ``load`` directly
    (``from torch import load``) it keeps the unpatched function -- confirm.
    """
    # Explicit caller keywords come last so they override the default.
    return _orig_load(*args, **{"weights_only": False, **kwargs})


torch.load = _load_patch
from modelscope.pipelines import pipeline
from modelscope.outputs import OutputKeys
def safe_stem(text: str, maxlen: int = 60) -> str:
    """Turn a prompt into a filesystem-safe filename fragment.

    Steps:
      1. Collapse whitespace runs to "_".
      2. Drop every character outside ``[A-Za-z0-9_-]``.
      3. Cut the result to *maxlen* characters.
      4. Strip leading/trailing underscores.

    If nothing usable remains (e.g. a prompt written entirely in non-ASCII
    characters), the fallback ``"image"`` is returned.

    Args:
        text: Arbitrary prompt text.
        maxlen: Maximum number of characters kept from the cleaned text.

    Returns:
        A non-empty string safe to embed in a filename.
    """
    text = re.sub(r"\s+", "_", text.strip())
    text = re.sub(r"[^A-Za-z0-9_\-]+", "", text)
    # The fallback must be applied *after* stripping underscores: a prompt
    # like "一只 猫" reduces to "_", which is truthy but strips down to "".
    stem = text[:maxlen].strip("_")
    return stem or "image"
def load_prompts(json_path: Path):
    """Read a JSON prompt file and normalise it to a list of dicts.

    Two top-level layouts are accepted:
      1) ``["prompt 1", "prompt 2", ...]`` -- each string is wrapped as
         ``{"prompt": <string>}``.
      2) ``[{"prompt": "prompt 1", ...}, ...]`` -- the objects are returned
         as-is (same dict objects, in a new list); each must have ``"prompt"``.

    An empty list is returned unchanged as ``[]``.

    Raises:
        ValueError: if the top level is not a list, if the elements are not
            all strings or all objects, or if an object lacks ``"prompt"``.
    """
    with open(json_path, "r", encoding="utf-8") as fh:
        payload = json.load(fh)

    if not isinstance(payload, list):
        raise ValueError("JSON 顶层必须是列表。")

    if all(isinstance(item, str) for item in payload):
        return [{"prompt": item} for item in payload]

    if not all(isinstance(item, dict) for item in payload):
        raise ValueError("JSON 列表元素需全为字符串或全为对象。")

    for item in payload:
        if "prompt" not in item:
            raise ValueError("每个对象都需要包含 'prompt' 字段")
    return list(payload)
def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16):
    """Build the ModelScope text-to-video pipeline.

    Args:
        model_path: Model directory or model id, passed straight to
            ModelScope's ``pipeline``.
        device: Inference device ("cuda" or "cpu"), forwarded to
            ``pipeline(device=...)``.
        dtype: Accepted so this function shares a keyword signature with the
            Wan backend builder (``main`` calls both the same way). It is NOT
            applied to this pipeline.

    Returns:
        The ModelScope pipeline object, which ``generate_one`` calls with
        ``{"text": prompt}``.
    """
    # `device` used to be ignored (only a commented-out `pipe.to("cuda")`
    # remained), so `--device cpu` had no effect on this backend.
    return pipeline('text-to-video-synthesis', model_path, device=device)
def generate_one(pipe, cfg: dict, out_dir: Path, index: int):
    """Generate one video for ``cfg["prompt"]`` and time the pipeline call.

    Args:
        pipe: Pipeline from ``build_pipeline``; called as
            ``pipe({"text": prompt}, output_video=<path>)``.
        cfg: Prompt record. Only the required ``"prompt"`` key is read.
        out_dir: Directory the .mp4 file is requested in.
        index: Position of the prompt (``main`` passes 1-based indices); used
            as a zero-padded filename prefix.

    Returns:
        Tuple ``(video_path, elapsed_seconds, detail)``. ``detail`` is the
        per-case record for the results JSON, with keys ``index``,
        ``filename``, ``elapsed_seconds`` and ``prompt``.

    Raises:
        KeyError: if ``cfg`` has no ``"prompt"`` or the pipeline output has no
            ``OutputKeys.OUTPUT_VIDEO`` entry.
    """
    prompt = cfg["prompt"]
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    requested_path = out_dir / f"{index:03d}_{safe_stem(prompt)}_{stamp}.mp4"

    # perf_counter is monotonic and high-resolution. time.time() follows the
    # wall clock and can jump (NTP sync, DST), which would skew the benchmark.
    start = time.perf_counter()
    output_video_path = pipe({"text": prompt}, output_video=str(requested_path))[OutputKeys.OUTPUT_VIDEO]
    elapsed = time.perf_counter() - start

    # Report the file the pipeline says it wrote rather than assuming it
    # honoured `output_video`. Fall back to the requested path if it gave none.
    out_path = Path(output_video_path) if output_video_path else requested_path

    detail = {
        "index": index,
        "filename": out_path.name,
        "elapsed_seconds": round(elapsed, 6),
        "prompt": prompt,
    }
    return out_path, elapsed, detail
def main():
    """Command-line entry point: batch-generate videos and write a timing report.

    Steps:
      1. Read prompts from ``--json``.
      2. Pick a backend. The Wan pipeline is used when the model directory's
         basename starts with "wan" (case-insensitive); otherwise the
         ModelScope text-to-video pipeline is used.
      3. Generate one video per prompt into ``--outdir``.
      4. Write a summary JSON to ``--results``: per-case records, total
         elapsed time and average latency per case.

    Raises:
        ValueError: if the prompt file is malformed or contains no prompts.
    """
    parser = argparse.ArgumentParser(
        description="文生视频基准与批量生成脚本(输出 JSON 结果)"
    )
    parser.add_argument("--model", required=True, help="模型路径或模型名(本地目录或 HF 仓库名)")
    parser.add_argument("--json", required=True, help="测试文本 JSON 文件路径")
    parser.add_argument("--results", required=True, help="结果 JSON 文件输出路径(*.json)")
    parser.add_argument("--outdir", required=True, help="视频输出目录")
    parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="推理设备")
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"], help="推理精度")
    # parse_known_args tolerates extra flags (presumably consumed elsewhere --
    # TODO confirm). Report them so a mistyped option is not silently dropped.
    args, unknown = parser.parse_known_args()
    if unknown:
        print(f"Warning: ignoring unrecognized arguments: {' '.join(unknown)}", file=sys.stderr)

    model_path = args.model
    json_path = Path(args.json)
    results_path = Path(args.results)
    out_dir = Path(args.outdir)
    out_dir.mkdir(parents=True, exist_ok=True)
    results_path.parent.mkdir(parents=True, exist_ok=True)
    dtype = torch.float16 if args.dtype == "fp16" else torch.float32

    prompts = load_prompts(json_path)
    if not prompts:
        raise ValueError("测试列表为空。")

    # Backend dispatch on the model directory name. realpath resolves symlinks
    # and drops a trailing slash before the basename is taken.
    model_dir_name = os.path.basename(os.path.realpath(model_path))
    if model_dir_name.lower().startswith('wan'):
        build_fn, generate_fn = wan_build_pipeline, wan_generate_one
    else:
        build_fn, generate_fn = build_pipeline, generate_one

    pipe = build_fn(model_path=model_path, device=args.device, dtype=dtype)

    records = []
    n_prompts = len(prompts)
    # Monotonic clock for interval timing (time.time() can jump).
    total_start = time.perf_counter()
    for i, cfg in enumerate(prompts, 1):
        out_path, elapsed, detail = generate_fn(pipe, cfg, out_dir, i)
        print(f"[{i}/{n_prompts}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
        records.append(detail)
    total_elapsed = round(time.perf_counter() - total_start, 6)
    # The total includes the per-iteration print/bookkeeping, not just the
    # pipeline calls.
    avg_latency = total_elapsed / len(records) if records else 0

    result_obj = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "model": model_path,
        # Falls back to "unknown" when the pipeline object exposes no `.device`.
        "device": str(getattr(pipe, "device", "unknown")),
        "dtype": "fp16" if dtype == torch.float16 else "fp32",
        "count": len(records),
        "total_elapsed_seconds": total_elapsed,
        "avg_latency": avg_latency,
        "cases": records,
    }
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(result_obj, f, ensure_ascii=False, indent=2)

    print(f"\nAll done. videos: {len(records)}, total_elapsed: {total_elapsed:.3f}s, avg_latency: {avg_latency:.3f}s")
    print(f"Results JSON: {results_path}")
    print(f"Videos dir  : {out_dir.resolve()}")
if __name__ == "__main__":
    # Environment diagnostics printed before the run, for the benchmark log.
    print(torch.__version__)
    # None on CPU-only PyTorch builds.
    print("CUDA Version: ", torch.version.cuda)
    if torch.cuda.is_available():
        print("Device name:", torch.cuda.get_device_properties("cuda").name)
    else:
        # get_device_properties raises without a usable GPU, which used to
        # abort even `--device cpu` runs before main() was reached.
        print("Device name:", "N/A (CUDA not available)")
    # flash_sdp_enabled() reports whether the flash scaled-dot-product-attention
    # backend is *enabled* (a global toggle), not whether the hardware supports it.
    print("FlashAttention SDP enabled:", torch.backends.cuda.flash_sdp_enabled())
    main()