feat: support wan2.1
This commit is contained in:
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[submodule "wan21"]
|
||||||
|
path = Wan2.1
|
||||||
|
url = https://github.com/Wan-Video/Wan2.1
|
||||||
@@ -2,8 +2,11 @@ FROM cr.metax-tech.com/public-ai-release/maca/diffusers.training:maca.ai3.0.0.5-
|
|||||||
|
|
||||||
RUN /opt/conda/bin/pip install pytorch_lightning opencv-python-headless==4.10.0.84 imageio[ffmpeg] einops datasets==3.2.0 simplejson open_clip_torch==2.24.0 sortedcontainers modelscope av==11.0.0 addict
|
RUN /opt/conda/bin/pip install pytorch_lightning opencv-python-headless==4.10.0.84 imageio[ffmpeg] einops datasets==3.2.0 simplejson open_clip_torch==2.24.0 sortedcontainers modelscope av==11.0.0 addict
|
||||||
|
|
||||||
|
RUN /opt/conda/bin/pip install easydict dashscope
|
||||||
|
|
||||||
WORKDIR /opt/app
|
WORKDIR /opt/app
|
||||||
|
|
||||||
COPY ./main.py ./dataset.json ./
|
COPY ./main.py ./dataset.json ./wan_pipeline.py ./
|
||||||
|
COPY ./Wan2.1 ./Wan2.1
|
||||||
|
|
||||||
ENTRYPOINT ["/opt/conda/bin/python"]
|
ENTRYPOINT ["/opt/conda/bin/python"]
|
||||||
|
|||||||
1
Wan2.1
Submodule
1
Wan2.1
Submodule
Submodule Wan2.1 added at 7c81b2f27d
18
main.py
18
main.py
@@ -4,6 +4,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
@@ -12,6 +13,11 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
|
||||||
|
sys.path.append('./Wan2.1/')
|
||||||
|
|
||||||
|
from wan_pipeline import build_pipeline as wan_build_pipeline, generate_one as wan_generate_one
|
||||||
|
|
||||||
|
|
||||||
_orig_load = torch.load
|
_orig_load = torch.load
|
||||||
|
|
||||||
@wraps(_orig_load)
|
@wraps(_orig_load)
|
||||||
@@ -114,12 +120,20 @@ def main():
|
|||||||
if not prompts:
|
if not prompts:
|
||||||
raise ValueError("测试列表为空。")
|
raise ValueError("测试列表为空。")
|
||||||
|
|
||||||
pipe = build_pipeline(model_path=model_path, device=args.device, dtype=dtype)
|
|
||||||
|
|
||||||
|
model_dir_name = os.path.basename(os.path.realpath(model_path))
|
||||||
|
if model_dir_name.lower().startswith('wan'):
|
||||||
|
build_fn = wan_build_pipeline
|
||||||
|
generate_fn = wan_generate_one
|
||||||
|
else:
|
||||||
|
build_fn = build_pipeline
|
||||||
|
generate_fn = generate_one
|
||||||
|
|
||||||
|
pipe = build_fn(model_path=model_path, device=args.device, dtype=dtype)
|
||||||
records = []
|
records = []
|
||||||
total_start = time.time()
|
total_start = time.time()
|
||||||
for i, cfg in enumerate(prompts, 1):
|
for i, cfg in enumerate(prompts, 1):
|
||||||
out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i)
|
out_path, elapsed, detail = generate_fn(pipe, cfg, out_dir, i)
|
||||||
print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
|
print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
|
||||||
records.append(detail)
|
records.append(detail)
|
||||||
total_elapsed = round(time.time() - total_start, 6)
|
total_elapsed = round(time.time() - total_start, 6)
|
||||||
|
|||||||
212
wan_pipeline.py
Normal file
212
wan_pipeline.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import sys
|
||||||
|
import warnings
|
||||||
|
from datetime import datetime
|
||||||
|
import time
|
||||||
|
import re
|
||||||
|
|
||||||
|
warnings.filterwarnings('ignore')
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
import wan
|
||||||
|
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
|
||||||
|
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
||||||
|
from wan.utils.utils import cache_image, cache_video, str2bool
|
||||||
|
|
||||||
|
|
||||||
|
EXAMPLE_PROMPT = {
|
||||||
|
"t2v-1.3B": {
|
||||||
|
"prompt":
|
||||||
|
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
||||||
|
},
|
||||||
|
"t2v-14B": {
|
||||||
|
"prompt":
|
||||||
|
"Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
|
||||||
|
},
|
||||||
|
"t2i-14B": {
|
||||||
|
"prompt": "一个朴素端庄的美人",
|
||||||
|
},
|
||||||
|
"i2v-14B": {
|
||||||
|
"prompt":
|
||||||
|
"Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
|
||||||
|
"image":
|
||||||
|
"examples/i2v_input.JPG",
|
||||||
|
},
|
||||||
|
"flf2v-14B": {
|
||||||
|
"prompt":
|
||||||
|
"CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。",
|
||||||
|
"first_frame":
|
||||||
|
"examples/flf2v_input_first_frame.png",
|
||||||
|
"last_frame":
|
||||||
|
"examples/flf2v_input_last_frame.png",
|
||||||
|
},
|
||||||
|
"vace-1.3B": {
|
||||||
|
"src_ref_images":
|
||||||
|
'examples/girl.png,examples/snake.png',
|
||||||
|
"prompt":
|
||||||
|
"在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
|
||||||
|
},
|
||||||
|
"vace-14B": {
|
||||||
|
"src_ref_images":
|
||||||
|
'examples/girl.png,examples/snake.png',
|
||||||
|
"prompt":
|
||||||
|
"在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class WanArgs:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.task ='t2v-1.3B'
|
||||||
|
self.size ='480*832'
|
||||||
|
self.frame_num =81
|
||||||
|
self.ckpt_dir =''
|
||||||
|
self.offload_model =None
|
||||||
|
self.ulysses_size =1
|
||||||
|
self.ring_size =1
|
||||||
|
self.t5_fsdp =False
|
||||||
|
self.t5_cpu =False
|
||||||
|
self.dit_fsdp =False
|
||||||
|
self.save_file =None
|
||||||
|
self.src_video =None
|
||||||
|
self.src_mask =None
|
||||||
|
self.src_ref_images =None
|
||||||
|
self.prompt ='Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.'
|
||||||
|
self.use_prompt_extend =False
|
||||||
|
self.prompt_extend_method ='local_qwen'
|
||||||
|
self.prompt_extend_model =None
|
||||||
|
self.prompt_extend_target_lang ='zh'
|
||||||
|
self.image =None
|
||||||
|
self.first_frame =None
|
||||||
|
self.last_frame =None
|
||||||
|
self.sample_solver ='unipc'
|
||||||
|
self.sample_steps =50
|
||||||
|
self.sample_shift =5.0
|
||||||
|
self.sample_guide_scale =5.0
|
||||||
|
self.base_seed = 0
|
||||||
|
|
||||||
|
args = WanArgs()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def _init_logging(rank):
|
||||||
|
# logging
|
||||||
|
if rank == 0:
|
||||||
|
# set format
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="[%(asctime)s] %(levelname)s: %(message)s",
|
||||||
|
handlers=[logging.StreamHandler(stream=sys.stdout)])
|
||||||
|
else:
|
||||||
|
logging.basicConfig(level=logging.ERROR)
|
||||||
|
|
||||||
|
|
||||||
|
def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16):
|
||||||
|
rank = int(os.getenv("RANK", 0))
|
||||||
|
world_size = int(os.getenv("WORLD_SIZE", 1))
|
||||||
|
local_rank = int(os.getenv("LOCAL_RANK", 0))
|
||||||
|
device_id = local_rank
|
||||||
|
_init_logging(rank)
|
||||||
|
|
||||||
|
args.ckpt_dir = model_path
|
||||||
|
|
||||||
|
if args.offload_model is None:
|
||||||
|
args.offload_model = False if world_size > 1 else True
|
||||||
|
logging.info(
|
||||||
|
f"offload_model is not specified, set to {args.offload_model}.")
|
||||||
|
if world_size > 1:
|
||||||
|
torch.cuda.set_device(device_id)
|
||||||
|
dist.init_process_group(
|
||||||
|
backend="nccl",
|
||||||
|
init_method="env://",
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size)
|
||||||
|
else:
|
||||||
|
assert not (
|
||||||
|
args.t5_fsdp or args.dit_fsdp
|
||||||
|
), f"t5_fsdp and dit_fsdp are not supported in non-distributed environments."
|
||||||
|
assert not (
|
||||||
|
args.ulysses_size > 1 or args.ring_size > 1
|
||||||
|
), f"context parallel are not supported in non-distributed environments."
|
||||||
|
cfg = WAN_CONFIGS[args.task]
|
||||||
|
logging.info(f"Generation job args: {args}")
|
||||||
|
logging.info(f"Generation model config: {cfg}")
|
||||||
|
|
||||||
|
if dist.is_initialized():
|
||||||
|
base_seed = [args.base_seed] if rank == 0 else [None]
|
||||||
|
dist.broadcast_object_list(base_seed, src=0)
|
||||||
|
args.base_seed = base_seed[0]
|
||||||
|
|
||||||
|
# logging.info(f"Input prompt: {args.prompt}")
|
||||||
|
logging.info("Creating WanT2V pipeline.")
|
||||||
|
wan_t2v = wan.WanT2V(
|
||||||
|
config=cfg,
|
||||||
|
checkpoint_dir=args.ckpt_dir,
|
||||||
|
device_id=device_id,
|
||||||
|
rank=rank,
|
||||||
|
t5_fsdp=args.t5_fsdp,
|
||||||
|
dit_fsdp=args.dit_fsdp,
|
||||||
|
use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
|
||||||
|
t5_cpu=args.t5_cpu,
|
||||||
|
)
|
||||||
|
|
||||||
|
return wan_t2v
|
||||||
|
|
||||||
|
def safe_stem(text: str, maxlen: int = 60) -> str:
|
||||||
|
"""将提示词转为安全的文件名片段。"""
|
||||||
|
text = re.sub(r"\s+", "_", text.strip())
|
||||||
|
text = re.sub(r"[^A-Za-z0-9_\-]+", "", text)
|
||||||
|
return (text[:maxlen] or "image").strip("_")
|
||||||
|
|
||||||
|
def generate_one(pipe, cfg, out_dir: Path, index: int):
|
||||||
|
"""
|
||||||
|
依据 cfg 生成一张图并返回 (保存路径, 耗时秒, 详细参数)
|
||||||
|
支持字段:
|
||||||
|
- prompt (必需)
|
||||||
|
"""
|
||||||
|
|
||||||
|
global args
|
||||||
|
|
||||||
|
prompt = cfg["prompt"]
|
||||||
|
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||||
|
stem = safe_stem(prompt)
|
||||||
|
filename = f"{index:03d}_{stem}_{stamp}.mp4"
|
||||||
|
out_path = out_dir / filename
|
||||||
|
|
||||||
|
start = time.time()
|
||||||
|
video = pipe.generate(
|
||||||
|
prompt,
|
||||||
|
size=SIZE_CONFIGS[args.size],
|
||||||
|
frame_num=args.frame_num,
|
||||||
|
shift=args.sample_shift,
|
||||||
|
sample_solver=args.sample_solver,
|
||||||
|
sampling_steps=args.sample_steps,
|
||||||
|
guide_scale=args.sample_guide_scale,
|
||||||
|
seed=args.base_seed,
|
||||||
|
offload_model=args.offload_model)
|
||||||
|
|
||||||
|
elapsed = time.time() - start
|
||||||
|
wan_cfg = WAN_CONFIGS[args.task]
|
||||||
|
cache_video(
|
||||||
|
tensor=video[None],
|
||||||
|
save_file=out_path,
|
||||||
|
fps=wan_cfg.sample_fps,
|
||||||
|
nrow=1,
|
||||||
|
normalize=True,
|
||||||
|
value_range=(-1, 1))
|
||||||
|
|
||||||
|
detail = {
|
||||||
|
"index": index,
|
||||||
|
"filename": filename,
|
||||||
|
"elapsed_seconds": round(elapsed, 6),
|
||||||
|
"prompt": prompt
|
||||||
|
}
|
||||||
|
return out_path, elapsed, detail
|
||||||
|
|
||||||
Reference in New Issue
Block a user