feat: support wan2.1
This commit is contained in:
18
main.py
18
main.py
@@ -4,6 +4,7 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
@@ -12,6 +13,11 @@ from pathlib import Path
|
||||
import torch
|
||||
from functools import wraps
|
||||
|
||||
sys.path.append('./Wan2.1/')
|
||||
|
||||
from wan_pipeline import build_pipeline as wan_build_pipeline, generate_one as wan_generate_one
|
||||
|
||||
|
||||
_orig_load = torch.load
|
||||
|
||||
@wraps(_orig_load)
|
||||
@@ -114,12 +120,20 @@ def main():
|
||||
if not prompts:
|
||||
raise ValueError("测试列表为空。")
|
||||
|
||||
pipe = build_pipeline(model_path=model_path, device=args.device, dtype=dtype)
|
||||
|
||||
model_dir_name = os.path.basename(os.path.realpath(model_path))
|
||||
if model_dir_name.lower().startswith('wan'):
|
||||
build_fn = wan_build_pipeline
|
||||
generate_fn = wan_generate_one
|
||||
else:
|
||||
build_fn = build_pipeline
|
||||
generate_fn = generate_one
|
||||
|
||||
pipe = build_fn(model_path=model_path, device=args.device, dtype=dtype)
|
||||
records = []
|
||||
total_start = time.time()
|
||||
for i, cfg in enumerate(prompts, 1):
|
||||
out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i)
|
||||
out_path, elapsed, detail = generate_fn(pipe, cfg, out_dir, i)
|
||||
print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
|
||||
records.append(detail)
|
||||
total_elapsed = round(time.time() - total_start, 6)
|
||||
|
||||
Reference in New Issue
Block a user