feat: support wan2.1
This commit is contained in:
3
.gitmodules
vendored
Normal file
3
.gitmodules
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
[submodule "wan21"]
|
||||
path = Wan2.1
|
||||
url = https://github.com/Wan-Video/Wan2.1
|
||||
@@ -2,8 +2,11 @@ FROM cr.metax-tech.com/public-ai-release/maca/diffusers.training:maca.ai3.0.0.5-
|
||||
|
||||
RUN /opt/conda/bin/pip install pytorch_lightning opencv-python-headless==4.10.0.84 imageio[ffmpeg] einops datasets==3.2.0 simplejson open_clip_torch==2.24.0 sortedcontainers modelscope av==11.0.0 addict
|
||||
|
||||
RUN /opt/conda/bin/pip install easydict dashscope
|
||||
|
||||
WORKDIR /opt/app
|
||||
|
||||
COPY ./main.py ./dataset.json ./
|
||||
COPY ./main.py ./dataset.json ./wan_pipeline.py ./
|
||||
COPY ./Wan2.1 ./Wan2.1
|
||||
|
||||
ENTRYPOINT ["/opt/conda/bin/python"]
|
||||
|
||||
1
Wan2.1
Submodule
1
Wan2.1
Submodule
Submodule Wan2.1 added at 7c81b2f27d
18
main.py
18
main.py
@@ -4,6 +4,7 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
@@ -12,6 +13,11 @@ from pathlib import Path
|
||||
import torch
|
||||
from functools import wraps
|
||||
|
||||
sys.path.append('./Wan2.1/')
|
||||
|
||||
from wan_pipeline import build_pipeline as wan_build_pipeline, generate_one as wan_generate_one
|
||||
|
||||
|
||||
_orig_load = torch.load
|
||||
|
||||
@wraps(_orig_load)
|
||||
@@ -114,12 +120,20 @@ def main():
|
||||
if not prompts:
|
||||
raise ValueError("测试列表为空。")
|
||||
|
||||
pipe = build_pipeline(model_path=model_path, device=args.device, dtype=dtype)
|
||||
|
||||
model_dir_name = os.path.basename(os.path.realpath(model_path))
|
||||
if model_dir_name.lower().startswith('wan'):
|
||||
build_fn = wan_build_pipeline
|
||||
generate_fn = wan_generate_one
|
||||
else:
|
||||
build_fn = build_pipeline
|
||||
generate_fn = generate_one
|
||||
|
||||
pipe = build_fn(model_path=model_path, device=args.device, dtype=dtype)
|
||||
records = []
|
||||
total_start = time.time()
|
||||
for i, cfg in enumerate(prompts, 1):
|
||||
out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i)
|
||||
out_path, elapsed, detail = generate_fn(pipe, cfg, out_dir, i)
|
||||
print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
|
||||
records.append(detail)
|
||||
total_elapsed = round(time.time() - total_start, 6)
|
||||
|
||||
212
wan_pipeline.py
Normal file
212
wan_pipeline.py
Normal file
@@ -0,0 +1,212 @@
|
||||
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
import time
|
||||
import re
|
||||
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from PIL import Image
|
||||
|
||||
import wan
|
||||
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
|
||||
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
||||
from wan.utils.utils import cache_image, cache_video, str2bool
|
||||
|
||||
|
||||
# Prompt texts that are shared by more than one task entry below.
_CATS_BOXING_PROMPT = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."
_VACE_REF_IMAGES = 'examples/girl.png,examples/snake.png'
_VACE_NEW_YEAR_PROMPT = "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"

# Example inputs keyed by task name. Each entry holds a "prompt" and, for some
# tasks, the paths of the example media files that accompany it.
EXAMPLE_PROMPT = {
    "t2v-1.3B": {"prompt": _CATS_BOXING_PROMPT},
    "t2v-14B": {"prompt": _CATS_BOXING_PROMPT},
    "t2i-14B": {"prompt": "一个朴素端庄的美人"},
    "i2v-14B": {
        "prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "image": "examples/i2v_input.JPG",
    },
    "flf2v-14B": {
        "prompt": "CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。",
        "first_frame": "examples/flf2v_input_first_frame.png",
        "last_frame": "examples/flf2v_input_last_frame.png",
    },
    "vace-1.3B": {
        "src_ref_images": _VACE_REF_IMAGES,
        "prompt": _VACE_NEW_YEAR_PROMPT,
    },
    "vace-14B": {
        "src_ref_images": _VACE_REF_IMAGES,
        "prompt": _VACE_NEW_YEAR_PROMPT,
    },
}
|
||||
|
||||
|
||||
class WanArgs:
    """Plain attribute container holding the Wan generation settings.

    The attributes stand in for command-line options: every setting gets a
    fixed default here, and the functions in this module read (and in a few
    cases overwrite) them on the shared module-level ``args`` instance.
    """

    def __init__(self) -> None:
        # Key into WAN_CONFIGS / key into SIZE_CONFIGS.
        self.task = 't2v-1.3B'
        self.size = '480*832'
        # Passed to the pipeline's generate() call as ``frame_num``.
        self.frame_num = 81
        # Overwritten by build_pipeline() with the requested model path.
        self.ckpt_dir = ''
        # None means "decide in build_pipeline()" based on WORLD_SIZE.
        self.offload_model = None
        # Parallelism / sharding switches checked by build_pipeline().
        self.ulysses_size = 1
        self.ring_size = 1
        self.t5_fsdp = False
        self.t5_cpu = False
        self.dit_fsdp = False
        # The settings from here down to last_frame are not read by the
        # functions in this module.
        self.save_file = None
        self.src_video = None
        self.src_mask = None
        self.src_ref_images = None
        self.prompt = 'Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.'
        self.use_prompt_extend = False
        self.prompt_extend_method = 'local_qwen'
        self.prompt_extend_model = None
        self.prompt_extend_target_lang = 'zh'
        self.image = None
        self.first_frame = None
        self.last_frame = None
        # Sampler settings forwarded to the pipeline's generate() call.
        self.sample_solver = 'unipc'
        self.sample_steps = 50
        self.sample_shift = 5.0
        self.sample_guide_scale = 5.0
        self.base_seed = 0

    def __repr__(self) -> str:
        """Return ``WanArgs(name=value, ...)`` so logged settings are readable."""
        rendered = ", ".join(
            f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({rendered})"


# Shared settings instance used by build_pipeline() and generate_one().
args = WanArgs()
|
||||
|
||||
|
||||
|
||||
def _init_logging(rank):
|
||||
# logging
|
||||
if rank == 0:
|
||||
# set format
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="[%(asctime)s] %(levelname)s: %(message)s",
|
||||
handlers=[logging.StreamHandler(stream=sys.stdout)])
|
||||
else:
|
||||
logging.basicConfig(level=logging.ERROR)
|
||||
|
||||
|
||||
def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16):
    """Create a ``wan.WanT2V`` pipeline for the checkpoint at ``model_path``.

    Rank information is read from the ``RANK``, ``WORLD_SIZE`` and
    ``LOCAL_RANK`` environment variables. The module-level ``args`` object is
    updated in place: ``ckpt_dir`` is set to ``model_path``, an unset
    ``offload_model`` is resolved, and ``base_seed`` is replaced by rank 0's
    value when a process group is initialised.

    Args:
        model_path: Checkpoint directory handed to ``wan.WanT2V``.
        device: Accepted for signature compatibility; not used in this body
            (the device index comes from ``LOCAL_RANK``).
        dtype: Accepted for signature compatibility; not used in this body.

    Returns:
        The constructed ``wan.WanT2V`` instance.

    Raises:
        AssertionError: In a single-process run when FSDP or context
            parallelism is enabled on ``args``.
    """
    global_rank = int(os.getenv("RANK", 0))
    n_procs = int(os.getenv("WORLD_SIZE", 1))
    gpu_index = int(os.getenv("LOCAL_RANK", 0))
    is_distributed = n_procs > 1

    _init_logging(global_rank)
    args.ckpt_dir = model_path

    # Offloading defaults to on for a single process and off when distributed.
    if args.offload_model is None:
        args.offload_model = not is_distributed
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")

    uses_context_parallel = args.ulysses_size > 1 or args.ring_size > 1

    if is_distributed:
        torch.cuda.set_device(gpu_index)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=global_rank,
            world_size=n_procs,
        )
    else:
        assert not args.t5_fsdp and not args.dit_fsdp, (
            "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        )
        assert not uses_context_parallel, (
            "context parallel are not supported in non-distributed environments."
        )

    model_cfg = WAN_CONFIGS[args.task]
    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {model_cfg}")

    # Make every rank use the seed held by rank 0.
    if dist.is_initialized():
        seed_box = [args.base_seed if global_rank == 0 else None]
        dist.broadcast_object_list(seed_box, src=0)
        args.base_seed = seed_box[0]

    logging.info("Creating WanT2V pipeline.")
    return wan.WanT2V(
        config=model_cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=gpu_index,
        rank=global_rank,
        t5_fsdp=args.t5_fsdp,
        dit_fsdp=args.dit_fsdp,
        use_usp=uses_context_parallel,
        t5_cpu=args.t5_cpu,
    )
|
||||
|
||||
def safe_stem(text: str, maxlen: int = 60) -> str:
    """Turn a prompt into a filesystem-safe filename fragment.

    Whitespace runs become a single underscore, every character outside
    ``A-Za-z0-9_-`` is removed, the result is cut to ``maxlen`` characters
    and surrounding underscores are trimmed.

    Args:
        text: The prompt text to sanitise.
        maxlen: Maximum number of characters kept before trimming.

    Returns:
        The sanitised fragment, or ``"image"`` when nothing usable remains
        (for example an empty prompt or one made only of non-ASCII words).
    """
    text = re.sub(r"\s+", "_", text.strip())
    text = re.sub(r"[^A-Za-z0-9_\-]+", "", text)
    # Trim before applying the fallback: a prompt such as "一只 小鸟" reduces
    # to "_" here, which must not end up as an empty stem.
    stem = text[:maxlen].strip("_")
    return stem or "image"
|
||||
|
||||
def generate_one(pipe, cfg, out_dir: Path, index: int):
    """Generate one video for ``cfg["prompt"]`` and save it as an ``.mp4``.

    Sampling settings (size, frame count, solver, steps, shift, guidance
    scale, seed, offloading) are taken from the module-level ``args``; the
    only field read from ``cfg`` is ``"prompt"``, which is required.

    Args:
        pipe: Pipeline object exposing ``generate(prompt, ...)``.
        cfg: Mapping that contains the ``"prompt"`` string.
        out_dir: Directory the video file is written into.
        index: Sequence number, zero-padded to three digits in the filename.

    Returns:
        A ``(out_path, elapsed, detail)`` tuple: the output path, the seconds
        spent inside ``pipe.generate`` (saving is not included), and a dict
        with ``index``, ``filename``, ``elapsed_seconds`` and ``prompt``.
    """
    prompt = cfg["prompt"]
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    filename = f"{index:03d}_{safe_stem(prompt)}_{timestamp}.mp4"
    out_path = out_dir / filename

    started_at = time.time()
    sampling_kwargs = dict(
        size=SIZE_CONFIGS[args.size],
        frame_num=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model,
    )
    video = pipe.generate(prompt, **sampling_kwargs)
    elapsed = time.time() - started_at

    # video[None] adds a leading axis before the tensor is handed to cache_video.
    cache_video(
        tensor=video[None],
        save_file=out_path,
        fps=WAN_CONFIGS[args.task].sample_fps,
        nrow=1,
        normalize=True,
        value_range=(-1, 1),
    )

    detail = dict(
        index=index,
        filename=filename,
        elapsed_seconds=round(elapsed, 6),
        prompt=prompt,
    )
    return out_path, elapsed, detail
|
||||
|
||||
Reference in New Issue
Block a user