From d1ad5eee12dc3673fe5c1afce301df69d1eb8969 Mon Sep 17 00:00:00 2001 From: qiliguo Date: Thu, 16 Oct 2025 10:45:15 +0800 Subject: [PATCH] feat: support wan2.1 --- .gitmodules | 3 + Dockerfile | 5 +- Wan2.1 | 1 + main.py | 18 +++- wan_pipeline.py | 212 ++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 236 insertions(+), 3 deletions(-) create mode 100644 .gitmodules create mode 160000 Wan2.1 create mode 100644 wan_pipeline.py diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..4751634 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "wan21"] + path = Wan2.1 + url = https://github.com/Wan-Video/Wan2.1 diff --git a/Dockerfile b/Dockerfile index e6a6929..cb6d8f8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -2,8 +2,11 @@ FROM cr.metax-tech.com/public-ai-release/maca/diffusers.training:maca.ai3.0.0.5- RUN /opt/conda/bin/pip install pytorch_lightning opencv-python-headless==4.10.0.84 imageio[ffmpeg] einops datasets==3.2.0 simplejson open_clip_torch==2.24.0 sortedcontainers modelscope av==11.0.0 addict +RUN /opt/conda/bin/pip install easydict dashscope + WORKDIR /opt/app -COPY ./main.py ./dataset.json ./ +COPY ./main.py ./dataset.json ./wan_pipeline.py ./ +COPY ./Wan2.1 ./Wan2.1 ENTRYPOINT ["/opt/conda/bin/python"] diff --git a/Wan2.1 b/Wan2.1 new file mode 160000 index 0000000..7c81b2f --- /dev/null +++ b/Wan2.1 @@ -0,0 +1 @@ +Subproject commit 7c81b2f27defa56c7e627a4b6717c8f2292eee58 diff --git a/main.py b/main.py index 8d888f7..08b6d1c 100644 --- a/main.py +++ b/main.py @@ -4,6 +4,7 @@ import argparse import json import os +import sys import re import time from datetime import datetime @@ -12,6 +13,11 @@ from pathlib import Path import torch from functools import wraps +sys.path.append('./Wan2.1/') + +from wan_pipeline import build_pipeline as wan_build_pipeline, generate_one as wan_generate_one + + _orig_load = torch.load @wraps(_orig_load) @@ -114,12 +120,20 @@ def main(): if not prompts: raise ValueError("测试列表为空。") - 
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
"""Wan2.1 text-to-video pipeline adapter.

Exposes ``build_pipeline`` and ``generate_one`` with the same calling
convention as the default helpers in ``main.py`` so the driver can dispatch
to either implementation based on the model directory name.

All sampling settings live in a module-level ``WanArgs`` instance
(``args``) that mirrors the CLI defaults of Wan2.1's ``generate.py``;
``build_pipeline`` mutates it in place (checkpoint dir, offload flag,
broadcast seed).
"""
import argparse
import logging
import os
import re
import sys
import time
import warnings
from datetime import datetime
from pathlib import Path

# Silence noisy upstream deprecation warnings before the heavy imports.
warnings.filterwarnings('ignore')

import torch
import torch.distributed as dist
from PIL import Image

import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image, cache_video, str2bool


# Reference prompts per task, copied from Wan2.1's generate.py.
# Only "t2v-1.3B" is exercised by this wrapper; the rest are kept for parity.
EXAMPLE_PROMPT = {
    "t2v-1.3B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "t2v-14B": {
        "prompt":
            "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "t2i-14B": {
        "prompt": "一个朴素端庄的美人",
    },
    "i2v-14B": {
        "prompt":
            "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "image":
            "examples/i2v_input.JPG",
    },
    "flf2v-14B": {
        "prompt":
            "CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。",
        "first_frame":
            "examples/flf2v_input_first_frame.png",
        "last_frame":
            "examples/flf2v_input_last_frame.png",
    },
    "vace-1.3B": {
        "src_ref_images":
            'examples/girl.png,examples/snake.png',
        "prompt":
            "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
    },
    "vace-14B": {
        "src_ref_images":
            'examples/girl.png,examples/snake.png',
        "prompt":
            "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
    }
}


class WanArgs:
    """Generation settings mirroring the CLI defaults of Wan2.1 generate.py.

    A plain attribute bag so the rest of the module can read ``args.x``
    exactly like generate.py reads its argparse namespace.
    """

    def __init__(self) -> None:
        self.task = 't2v-1.3B'          # key into WAN_CONFIGS
        self.size = '480*832'           # key into SIZE_CONFIGS
        self.frame_num = 81
        self.ckpt_dir = ''              # filled in by build_pipeline()
        self.offload_model = None       # None -> resolved in build_pipeline()
        self.ulysses_size = 1
        self.ring_size = 1
        self.t5_fsdp = False
        self.t5_cpu = False
        self.dit_fsdp = False
        self.save_file = None
        self.src_video = None
        self.src_mask = None
        self.src_ref_images = None
        self.prompt = ('Two anthropomorphic cats in comfy boxing gear and '
                       'bright gloves fight intensely on a spotlighted stage.')
        self.use_prompt_extend = False
        self.prompt_extend_method = 'local_qwen'
        self.prompt_extend_model = None
        self.prompt_extend_target_lang = 'zh'
        self.image = None
        self.first_frame = None
        self.last_frame = None
        self.sample_solver = 'unipc'
        self.sample_steps = 50
        self.sample_shift = 5.0
        self.sample_guide_scale = 5.0
        self.base_seed = 0


# Single shared settings object; build_pipeline() mutates it in place.
args = WanArgs()


def _init_logging(rank):
    """Configure logging: verbose INFO on rank 0, errors only elsewhere."""
    if rank == 0:
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16):
    """Create a WanT2V pipeline for the checkpoints under *model_path*.

    ``device`` and ``dtype`` are accepted only for interface compatibility
    with the default ``build_pipeline`` in main.py and are currently
    ignored: the device is derived from LOCAL_RANK and the dtype from the
    task config inside the Wan pipeline.

    Reads RANK / WORLD_SIZE / LOCAL_RANK from the environment and
    initializes torch.distributed (nccl) when WORLD_SIZE > 1; in that case
    rank 0's base seed is broadcast so every rank samples identically.

    Returns the constructed ``wan.WanT2V`` pipeline.
    """
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device_id = local_rank
    _init_logging(rank)

    args.ckpt_dir = model_path

    if args.offload_model is None:
        # Offloading trades speed for memory; with multi-GPU model
        # parallelism the weights are already sharded, so keep them resident.
        args.offload_model = False if world_size > 1 else True
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")
    if world_size > 1:
        torch.cuda.set_device(device_id)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.ulysses_size > 1 or args.ring_size > 1
        ), "context parallel are not supported in non-distributed environments."

    cfg = WAN_CONFIGS[args.task]
    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    if dist.is_initialized():
        # All ranks must use the same seed; take rank 0's.
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    logging.info("Creating WanT2V pipeline.")
    wan_t2v = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=device_id,
        rank=rank,
        t5_fsdp=args.t5_fsdp,
        dit_fsdp=args.dit_fsdp,
        use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
        t5_cpu=args.t5_cpu,
    )

    return wan_t2v


def safe_stem(text: str, maxlen: int = 60) -> str:
    """Turn a prompt into a safe, truncated filename fragment.

    Whitespace runs become underscores, everything outside [A-Za-z0-9_-]
    is dropped, and the result is capped at *maxlen* characters.  Falls
    back to "image" when nothing survives (e.g. a fully non-ASCII prompt).
    """
    text = re.sub(r"\s+", "_", text.strip())
    text = re.sub(r"[^A-Za-z0-9_\-]+", "", text)
    return (text[:maxlen] or "image").strip("_")


def generate_one(pipe, cfg, out_dir: Path, index: int):
    """Generate one video for *cfg* and return (out_path, elapsed_s, detail).

    Supported cfg fields:
      - prompt (required)
    Every other sampling parameter (size, steps, seed, ...) comes from the
    module-level ``args`` set up by ``build_pipeline``.
    """
    prompt = cfg["prompt"]
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    stem = safe_stem(prompt)
    filename = f"{index:03d}_{stem}_{stamp}.mp4"
    out_path = out_dir / filename

    start = time.time()
    video = pipe.generate(
        prompt,
        size=SIZE_CONFIGS[args.size],
        frame_num=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model)
    elapsed = time.time() - start

    wan_cfg = WAN_CONFIGS[args.task]
    # video[None] adds a batch dimension for cache_video; pixel values are
    # assumed to lie in [-1, 1] (see value_range) — matches Wan's output.
    cache_video(
        tensor=video[None],
        save_file=out_path,
        fps=wan_cfg.sample_fps,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))

    detail = {
        "index": index,
        "filename": filename,
        "elapsed_seconds": round(elapsed, 6),
        "prompt": prompt,
    }
    return out_path, elapsed, detail