feat: support wan2.1

This commit is contained in:
qiliguo
2025-10-16 10:45:15 +08:00
parent 386870431b
commit d1ad5eee12
5 changed files with 236 additions and 3 deletions

3
.gitmodules vendored Normal file
View File

@@ -0,0 +1,3 @@
[submodule "wan21"]
path = Wan2.1
url = https://github.com/Wan-Video/Wan2.1

View File

@@ -2,8 +2,11 @@ FROM cr.metax-tech.com/public-ai-release/maca/diffusers.training:maca.ai3.0.0.5-
RUN /opt/conda/bin/pip install pytorch_lightning opencv-python-headless==4.10.0.84 imageio[ffmpeg] einops datasets==3.2.0 simplejson open_clip_torch==2.24.0 sortedcontainers modelscope av==11.0.0 addict
RUN /opt/conda/bin/pip install easydict dashscope
WORKDIR /opt/app
COPY ./main.py ./dataset.json ./
COPY ./main.py ./dataset.json ./wan_pipeline.py ./
COPY ./Wan2.1 ./Wan2.1
ENTRYPOINT ["/opt/conda/bin/python"]

1
Wan2.1 Submodule

Submodule Wan2.1 added at 7c81b2f27d

18
main.py
View File

@@ -4,6 +4,7 @@
import argparse
import json
import os
import sys
import re
import time
from datetime import datetime
@@ -12,6 +13,11 @@ from pathlib import Path
import torch
from functools import wraps
sys.path.append('./Wan2.1/')
from wan_pipeline import build_pipeline as wan_build_pipeline, generate_one as wan_generate_one
_orig_load = torch.load
@wraps(_orig_load)
@@ -114,12 +120,20 @@ def main():
if not prompts:
raise ValueError("测试列表为空。")
pipe = build_pipeline(model_path=model_path, device=args.device, dtype=dtype)
model_dir_name = os.path.basename(os.path.realpath(model_path))
if model_dir_name.lower().startswith('wan'):
build_fn = wan_build_pipeline
generate_fn = wan_generate_one
else:
build_fn = build_pipeline
generate_fn = generate_one
pipe = build_fn(model_path=model_path, device=args.device, dtype=dtype)
records = []
total_start = time.time()
for i, cfg in enumerate(prompts, 1):
out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i)
out_path, elapsed, detail = generate_fn(pipe, cfg, out_dir, i)
print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
records.append(detail)
total_elapsed = round(time.time() - total_start, 6)

212
wan_pipeline.py Normal file
View File

@@ -0,0 +1,212 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import logging
import os
from pathlib import Path
import sys
import warnings
from datetime import datetime
import time
import re
warnings.filterwarnings('ignore')
import torch
import torch.distributed as dist
from PIL import Image
import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image, cache_video, str2bool
# Example prompts for each supported Wan2.1 task, copied from the upstream
# Wan2.1 repository. The prompt/image values are runtime data consumed by the
# generation entry points, so they are kept verbatim (several are Chinese);
# image paths are relative to the Wan2.1 checkout.
EXAMPLE_PROMPT = {
    "t2v-1.3B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "t2v-14B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "t2i-14B": {
        "prompt": "一个朴素端庄的美人",
    },
    "i2v-14B": {
        "prompt":
            "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "image":
            "examples/i2v_input.JPG",
    },
    "flf2v-14B": {
        "prompt":
            "CG动画风格一只蓝色的小鸟从地面起飞煽动翅膀。小鸟羽毛细腻胸前有独特的花纹背景是蓝天白云阳光明媚。镜跟随小鸟向上移动展现出小鸟飞翔的姿态和天空的广阔。近景仰视视角。",
        "first_frame":
            "examples/flf2v_input_first_frame.png",
        "last_frame":
            "examples/flf2v_input_last_frame.png",
    },
    "vace-1.3B": {
        "src_ref_images":
            'examples/girl.png,examples/snake.png',
        "prompt":
            "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
    },
    "vace-14B": {
        "src_ref_images":
            'examples/girl.png,examples/snake.png',
        "prompt":
            "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
    }
}
class WanArgs:
    """Default generation settings mirroring Wan2.1's command-line arguments.

    A plain attribute bag: ``build_pipeline`` and ``generate_one`` read (and
    in a few cases mutate) the module-level ``args`` instance below instead of
    parsing a real CLI.
    """

    def __init__(self) -> None:
        # Task / checkpoint selection (ckpt_dir is filled in by build_pipeline).
        self.task = 't2v-1.3B'
        self.ckpt_dir = ''
        # Output geometry.
        self.size = '480*832'
        self.frame_num = 81
        # Memory / parallelism knobs (None offload_model means "auto").
        self.offload_model = None
        self.ulysses_size = 1
        self.ring_size = 1
        self.t5_fsdp = False
        self.t5_cpu = False
        self.dit_fsdp = False
        # Optional inputs and outputs for the various tasks.
        self.save_file = None
        self.src_video = None
        self.src_mask = None
        self.src_ref_images = None
        self.prompt = 'Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.'
        self.image = None
        self.first_frame = None
        self.last_frame = None
        # Prompt-extension settings (disabled by default).
        self.use_prompt_extend = False
        self.prompt_extend_method = 'local_qwen'
        self.prompt_extend_model = None
        self.prompt_extend_target_lang = 'zh'
        # Sampler configuration.
        self.sample_solver = 'unipc'
        self.sample_steps = 50
        self.sample_shift = 5.0
        self.sample_guide_scale = 5.0
        self.base_seed = 0


# Module-level singleton consumed by build_pipeline() / generate_one().
args = WanArgs()
def _init_logging(rank):
# logging
if rank == 0:
# set format
logging.basicConfig(
level=logging.INFO,
format="[%(asctime)s] %(levelname)s: %(message)s",
handlers=[logging.StreamHandler(stream=sys.stdout)])
else:
logging.basicConfig(level=logging.ERROR)
def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16):
    """Create and return a ``wan.WanT2V`` text-to-video pipeline.

    Reads RANK / WORLD_SIZE / LOCAL_RANK from the environment to choose
    between single-process and distributed (NCCL) execution, then builds the
    pipeline from the checkpoints under ``model_path``.

    NOTE(review): ``device`` and ``dtype`` are accepted for signature
    compatibility with the non-Wan ``build_pipeline`` in main.py but are not
    used here — the CUDA device is derived from LOCAL_RANK. Confirm this is
    intentional.

    Side effects: mutates the module-level ``args`` (ckpt_dir, offload_model,
    base_seed), configures logging, and may initialize a process group.
    """
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device_id = local_rank
    _init_logging(rank)
    args.ckpt_dir = model_path
    if args.offload_model is None:
        # Auto default: offload weights in single-process runs (saves VRAM),
        # keep them resident when work is sharded across processes.
        args.offload_model = False if world_size > 1 else True
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")
    if world_size > 1:
        torch.cuda.set_device(device_id)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        # FSDP and context parallelism both require an initialized
        # process group, so reject those flags in single-process mode.
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.ulysses_size > 1 or args.ring_size > 1
        ), f"context parallel are not supported in non-distributed environments."
    cfg = WAN_CONFIGS[args.task]
    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")
    if dist.is_initialized():
        # Broadcast rank 0's seed so every rank samples identically.
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]
    # logging.info(f"Input prompt: {args.prompt}")
    logging.info("Creating WanT2V pipeline.")
    wan_t2v = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=device_id,
        rank=rank,
        t5_fsdp=args.t5_fsdp,
        dit_fsdp=args.dit_fsdp,
        use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
        t5_cpu=args.t5_cpu,
    )
    return wan_t2v
def safe_stem(text: str, maxlen: int = 60) -> str:
    """Turn a prompt into a filesystem-safe filename fragment.

    Whitespace runs become single underscores, all characters outside
    ``[A-Za-z0-9_-]`` are dropped, and the result is truncated to ``maxlen``.

    Args:
        text: Arbitrary prompt text (may be empty or entirely non-ASCII).
        maxlen: Maximum length of the returned fragment.

    Returns:
        A non-empty fragment; ``"image"`` when nothing printable survives.
    """
    text = re.sub(r"\s+", "_", text.strip())
    text = re.sub(r"[^A-Za-z0-9_\-]+", "", text)
    stem = text[:maxlen].strip("_")
    # Fix: apply the "image" fallback AFTER stripping underscores. The
    # original `(text[:maxlen] or "image").strip("_")` could still return ""
    # for inputs that reduce to only underscores (e.g. "___"), producing
    # filenames like "001__<stamp>.mp4".
    return stem or "image"
def generate_one(pipe, cfg, out_dir: Path, index: int):
    """Generate one video from ``cfg`` and save it under ``out_dir``.

    Supported cfg fields:
        - prompt (required)

    All sampling parameters (size, frame count, steps, seed, ...) come from
    the module-level ``args``; only the prompt varies per call.

    Args:
        pipe: Pipeline returned by ``build_pipeline`` (expected WanT2V).
        cfg: Dict holding at least the ``"prompt"`` key.
        out_dir: Directory the MP4 is written into.
        index: 1-based position in the batch, used in the filename.

    Returns:
        Tuple of (output Path, elapsed seconds, detail dict for the report).
    """
    global args
    prompt = cfg["prompt"]
    # Timestamp + sanitized prompt keeps filenames unique and readable.
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    stem = safe_stem(prompt)
    filename = f"{index:03d}_{stem}_{stamp}.mp4"
    out_path = out_dir / filename
    start = time.time()
    video = pipe.generate(
        prompt,
        size=SIZE_CONFIGS[args.size],
        frame_num=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model)
    elapsed = time.time() - start
    wan_cfg = WAN_CONFIGS[args.task]
    # video[None] prepends a batch dimension for cache_video; values are
    # presumably in [-1, 1] given value_range — confirm against WanT2V output.
    cache_video(
        tensor=video[None],
        save_file=out_path,
        fps=wan_cfg.sample_fps,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))
    detail = {
        "index": index,
        "filename": filename,
        "elapsed_seconds": round(elapsed, 6),
        "prompt": prompt
    }
    return out_path, elapsed, detail