feat: support wan2.1

This commit is contained in:
qiliguo
2025-10-16 10:45:15 +08:00
parent 386870431b
commit d1ad5eee12
5 changed files with 236 additions and 3 deletions

18
main.py
View File

@@ -4,6 +4,7 @@
import argparse
import json
import os
import sys
import re
import time
from datetime import datetime
@@ -12,6 +13,11 @@ from pathlib import Path
import torch
from functools import wraps
sys.path.append('./Wan2.1/')
from wan_pipeline import build_pipeline as wan_build_pipeline, generate_one as wan_generate_one
_orig_load = torch.load
@wraps(_orig_load)
@@ -114,12 +120,20 @@ def main():
if not prompts:
raise ValueError("测试列表为空。")
pipe = build_pipeline(model_path=model_path, device=args.device, dtype=dtype)
model_dir_name = os.path.basename(os.path.realpath(model_path))
if model_dir_name.lower().startswith('wan'):
build_fn = wan_build_pipeline
generate_fn = wan_generate_one
else:
build_fn = build_pipeline
generate_fn = generate_one
pipe = build_fn(model_path=model_path, device=args.device, dtype=dtype)
records = []
total_start = time.time()
for i, cfg in enumerate(prompts, 1):
out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i)
out_path, elapsed, detail = generate_fn(pipe, cfg, out_dir, i)
print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
records.append(detail)
total_elapsed = round(time.time() - total_start, 6)