213 lines
7.8 KiB
Python
213 lines
7.8 KiB
Python
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
||
import argparse
|
||
import logging
|
||
import os
|
||
from pathlib import Path
|
||
import sys
|
||
import warnings
|
||
from datetime import datetime
|
||
import time
|
||
import re
|
||
|
||
warnings.filterwarnings('ignore')
|
||
|
||
import torch
|
||
import torch.distributed as dist
|
||
from PIL import Image
|
||
|
||
import wan
|
||
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
|
||
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
|
||
from wan.utils.utils import cache_image, cache_video, str2bool
|
||
|
||
|
||
# Sample inputs keyed by task name. Each entry carries a "prompt" and, for
# some tasks, the example media paths that accompany it.

# Prompt text shared by the two t2v entries.
_BOXING_CATS_PROMPT = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage."

# Reference images and prompt text shared by the two vace entries.
_VACE_REF_IMAGES = 'examples/girl.png,examples/snake.png'
_VACE_PROMPT = "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"

EXAMPLE_PROMPT = {
    "t2v-1.3B": dict(prompt=_BOXING_CATS_PROMPT),
    "t2v-14B": dict(prompt=_BOXING_CATS_PROMPT),
    "t2i-14B": dict(prompt="一个朴素端庄的美人"),
    "i2v-14B": dict(
        prompt="Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        image="examples/i2v_input.JPG",
    ),
    "flf2v-14B": dict(
        prompt="CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。",
        first_frame="examples/flf2v_input_first_frame.png",
        last_frame="examples/flf2v_input_last_frame.png",
    ),
    # Key order (src_ref_images before prompt) is kept as in the original.
    "vace-1.3B": dict(src_ref_images=_VACE_REF_IMAGES, prompt=_VACE_PROMPT),
    "vace-14B": dict(src_ref_images=_VACE_REF_IMAGES, prompt=_VACE_PROMPT),
}
|
||
|
||
|
||
class WanArgs:
    """Mutable container of generation settings with fixed defaults.

    Instances take no constructor arguments; callers adjust settings by
    assigning to attributes after construction.
    """

    def __init__(self) -> None:
        # Task name; used elsewhere in this file as the key into WAN_CONFIGS.
        self.task = 't2v-1.3B'
        # Size name; used elsewhere in this file as the key into SIZE_CONFIGS.
        self.size = '480*832'
        self.frame_num = 81
        # Checkpoint directory; overwritten by build_pipeline(model_path).
        self.ckpt_dir = ''
        # None means "not specified"; build_pipeline resolves it to a bool.
        self.offload_model = None

        # Parallelism / sharding switches.
        self.ulysses_size = 1
        self.ring_size = 1
        self.t5_fsdp = False
        self.t5_cpu = False
        self.dit_fsdp = False

        # Optional file inputs and outputs.
        self.save_file = None
        self.src_video = None
        self.src_mask = None
        self.src_ref_images = None

        # Prompt and prompt-extension settings.
        self.prompt = 'Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.'
        self.use_prompt_extend = False
        self.prompt_extend_method = 'local_qwen'
        self.prompt_extend_model = None
        self.prompt_extend_target_lang = 'zh'

        # Optional conditioning images.
        self.image = None
        self.first_frame = None
        self.last_frame = None

        # Sampler settings.
        self.sample_solver = 'unipc'
        self.sample_steps = 50
        self.sample_shift = 5.0
        self.sample_guide_scale = 5.0
        self.base_seed = 0

    def __repr__(self) -> str:
        """Return ``WanArgs(name=value, ...)`` listing every current attribute.

        Without this, formatting an instance into a log message shows only
        the default object address instead of the settings.
        """
        fields = ", ".join(
            f"{name}={value!r}" for name, value in vars(self).items())
        return f"{type(self).__name__}({fields})"
|
||
|
||
# Module-level settings instance: build_pipeline writes resolved values onto
# it and generate_one reads the sampling parameters from it.
args = WanArgs()
|
||
|
||
|
||
|
||
def _init_logging(rank):
|
||
# logging
|
||
if rank == 0:
|
||
# set format
|
||
logging.basicConfig(
|
||
level=logging.INFO,
|
||
format="[%(asctime)s] %(levelname)s: %(message)s",
|
||
handlers=[logging.StreamHandler(stream=sys.stdout)])
|
||
else:
|
||
logging.basicConfig(level=logging.ERROR)
|
||
|
||
|
||
def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16):
    """Construct a ``wan.WanT2V`` pipeline for the checkpoint at ``model_path``.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment,
    configures logging for this rank, and stores ``model_path`` (plus a
    resolved ``offload_model`` and, when a process group is initialised, the
    rank-0 seed) on the module-level ``args`` object.

    When ``WORLD_SIZE`` is greater than 1 the CUDA device is set to the local
    rank and an NCCL process group is initialised; otherwise the FSDP and
    context-parallel settings are asserted to be off.

    ``device`` and ``dtype`` are accepted but not read by this function.

    Returns:
        The constructed ``wan.WanT2V`` instance.
    """
    read_env = os.getenv
    rank = int(read_env("RANK", 0))
    world_size = int(read_env("WORLD_SIZE", 1))
    local_rank = int(read_env("LOCAL_RANK", 0))
    device_id = local_rank
    _init_logging(rank)

    is_distributed = world_size > 1
    args.ckpt_dir = model_path

    # An unspecified offload flag defaults to on for single-process runs and
    # off for multi-process runs.
    if args.offload_model is None:
        args.offload_model = not is_distributed
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")

    if not is_distributed:
        assert not args.t5_fsdp and not args.dit_fsdp, \
            "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert args.ulysses_size <= 1 and args.ring_size <= 1, \
            "context parallel are not supported in non-distributed environments."
    else:
        torch.cuda.set_device(device_id)
        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            rank=rank,
            world_size=world_size,
        )

    cfg = WAN_CONFIGS[args.task]
    logging.info(f"Generation job args: {args}")
    logging.info(f"Generation model config: {cfg}")

    # Every rank adopts the seed held by rank 0.
    if dist.is_initialized():
        seed_box = [args.base_seed if rank == 0 else None]
        dist.broadcast_object_list(seed_box, src=0)
        args.base_seed = seed_box[0]

    logging.info("Creating WanT2V pipeline.")
    use_usp = args.ulysses_size > 1 or args.ring_size > 1
    return wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=device_id,
        rank=rank,
        t5_fsdp=args.t5_fsdp,
        dit_fsdp=args.dit_fsdp,
        use_usp=use_usp,
        t5_cpu=args.t5_cpu,
    )
|
||
|
||
def safe_stem(text: str, maxlen: int = 60) -> str:
    """Turn a prompt into a filesystem-safe file-name fragment.

    Runs of whitespace become single underscores, every character outside
    ``A-Za-z0-9_-`` is dropped, the result is cut to ``maxlen`` characters,
    and leading/trailing underscores are removed.

    Args:
        text: Prompt text to convert.
        maxlen: Maximum number of characters kept before underscores are
            trimmed from the ends.

    Returns:
        The sanitised fragment, or ``"image"`` when nothing usable remains.
    """
    text = re.sub(r"\s+", "_", text.strip())
    text = re.sub(r"[^A-Za-z0-9_\-]+", "", text)
    # Trim underscores BEFORE applying the fallback: a prompt such as
    # "一个 美人" reduces to "_", which must fall back to "image" rather than
    # yielding an empty stem.
    stem = text[:maxlen].strip("_")
    return stem or "image"
|
||
|
||
def generate_one(pipe, cfg, out_dir: Path, index: int):
    """Generate one video from ``cfg["prompt"]`` and save it as an ``.mp4``.

    Sampling parameters (size, frame count, solver, steps, shift, guidance
    scale, seed, offload flag) are read from the module-level ``args``.

    Args:
        pipe: Pipeline object exposing ``generate(...)``.
        cfg: Mapping with a required ``"prompt"`` entry.
        out_dir: Directory the video is written into; created if missing.
        index: Sequence number, zero-padded to three digits in the file name.

    Returns:
        A tuple ``(out_path, elapsed, detail)`` where ``out_path`` is
        ``out_dir / filename``, ``elapsed`` is the number of seconds spent in
        ``pipe.generate`` only (saving is not timed), and ``detail`` is a
        dict with ``index``, ``filename``, ``elapsed_seconds`` and ``prompt``.

    Raises:
        KeyError: If ``cfg`` has no ``"prompt"`` entry.
    """
    prompt = cfg["prompt"]
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    filename = f"{index:03d}_{safe_stem(prompt)}_{stamp}.mp4"
    out_path = out_dir / filename

    # perf_counter is monotonic, so the measurement cannot be skewed by
    # system clock adjustments while generation is running.
    start = time.perf_counter()
    video = pipe.generate(
        prompt,
        size=SIZE_CONFIGS[args.size],
        frame_num=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model)
    elapsed = time.perf_counter() - start

    # NOTE(review): the pipeline presumably returns None on ranks that do not
    # hold the result in multi-process runs — confirm against the pipeline
    # implementation. Skipping the save avoids indexing None in that case.
    if video is not None:
        out_dir.mkdir(parents=True, exist_ok=True)
        wan_cfg = WAN_CONFIGS[args.task]
        cache_video(
            tensor=video[None],
            save_file=out_path,
            fps=wan_cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))

    detail = {
        "index": index,
        "filename": filename,
        "elapsed_seconds": round(elapsed, 6),
        "prompt": prompt,
    }
    return out_path, elapsed, detail
|
||
|