# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
"""Helpers to build a Wan text-to-video pipeline and render prompts to .mp4 files."""
import argparse
import logging
import os
from pathlib import Path
import sys
import warnings
from datetime import datetime
import time
import re

warnings.filterwarnings('ignore')

import torch
import torch.distributed as dist
from PIL import Image

import wan
from wan.configs import MAX_AREA_CONFIGS, SIZE_CONFIGS, SUPPORTED_SIZES, WAN_CONFIGS
from wan.utils.prompt_extend import DashScopePromptExpander, QwenPromptExpander
from wan.utils.utils import cache_image, cache_video, str2bool

# Sample inputs keyed by task name.
EXAMPLE_PROMPT = {
    "t2v-1.3B": {
        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "t2v-14B": {
        "prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
    },
    "t2i-14B": {
        "prompt": "一个朴素端庄的美人",
    },
    "i2v-14B": {
        "prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "image": "examples/i2v_input.JPG",
    },
    "flf2v-14B": {
        "prompt": "CG动画风格,一只蓝色的小鸟从地面起飞,煽动翅膀。小鸟羽毛细腻,胸前有独特的花纹,背景是蓝天白云,阳光明媚。镜跟随小鸟向上移动,展现出小鸟飞翔的姿态和天空的广阔。近景,仰视视角。",
        "first_frame": "examples/flf2v_input_first_frame.png",
        "last_frame": "examples/flf2v_input_last_frame.png",
    },
    "vace-1.3B": {
        "src_ref_images": 'examples/girl.png,examples/snake.png',
        "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
    },
    "vace-14B": {
        "src_ref_images": 'examples/girl.png,examples/snake.png',
        "prompt": "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。"
    }
}


class WanArgs:
    """Mutable bag of generation settings.

    A single module-level instance, ``args``, is updated by build_pipeline()
    (ckpt_dir, offload_model, base_seed) and read by generate_one().
    """

    def __init__(self) -> None:
        self.task = 't2v-1.3B'  # key into WAN_CONFIGS
        self.size = '480*832'  # key into SIZE_CONFIGS
        self.frame_num = 81
        self.ckpt_dir = ''  # overwritten by build_pipeline(model_path)
        self.offload_model = None  # None -> build_pipeline() picks a default
        self.ulysses_size = 1
        self.ring_size = 1
        self.t5_fsdp = False
        self.t5_cpu = False
        self.dit_fsdp = False
        self.save_file = None
        self.src_video = None
        self.src_mask = None
        self.src_ref_images = None
        self.prompt = 'Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.'
        self.use_prompt_extend = False
        self.prompt_extend_method = 'local_qwen'
        self.prompt_extend_model = None
        self.prompt_extend_target_lang = 'zh'
        self.image = None
        self.first_frame = None
        self.last_frame = None
        self.sample_solver = 'unipc'
        self.sample_steps = 50
        self.sample_shift = 5.0
        self.sample_guide_scale = 5.0
        self.base_seed = 0  # broadcast from rank 0 in distributed runs

    def __repr__(self) -> str:
        # Without this, the "Generation job args" log line in build_pipeline()
        # would only show the default object repr instead of the settings.
        fields = ", ".join(f"{key}={value!r}" for key, value in vars(self).items())
        return f"{type(self).__name__}({fields})"


args = WanArgs()


def _init_logging(rank):
    """Configure root logging: INFO to stdout on rank 0, ERROR-only on other ranks."""
    if rank == 0:
        logging.basicConfig(
            level=logging.INFO,
            format="[%(asctime)s] %(levelname)s: %(message)s",
            handlers=[logging.StreamHandler(stream=sys.stdout)])
    else:
        logging.basicConfig(level=logging.ERROR)


def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16):
    """Create a ``wan.WanT2V`` pipeline from the checkpoint directory ``model_path``.

    Process topology is read from the ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``
    environment variables. When ``WORLD_SIZE > 1`` an NCCL process group is
    initialized (once) and the seed in ``args.base_seed`` is broadcast from rank 0.
    The module-level ``args`` is updated in place (ckpt_dir, offload_model,
    base_seed).

    Args:
        model_path: checkpoint directory passed to ``wan.WanT2V``.
        device: accepted for API compatibility; currently ignored (the device
            is chosen from ``LOCAL_RANK``).
        dtype: accepted for API compatibility; currently ignored.

    Returns:
        The constructed ``wan.WanT2V`` instance.

    Raises:
        AssertionError: if FSDP or context parallelism is requested in a
            single-process run.
    """
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device_id = local_rank
    _init_logging(rank)

    args.ckpt_dir = model_path
    if args.offload_model is None:
        # Offload only in single-process runs.
        args.offload_model = world_size <= 1
        logging.info("offload_model is not specified, set to %s.", args.offload_model)

    if world_size > 1:
        torch.cuda.set_device(device_id)
        # Guard so build_pipeline() can be called more than once per process;
        # initializing the default process group twice raises.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="nccl",
                init_method="env://",
                rank=rank,
                world_size=world_size)
    else:
        assert not (args.t5_fsdp or args.dit_fsdp), \
            "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (args.ulysses_size > 1 or args.ring_size > 1), \
            "context parallel are not supported in non-distributed environments."

    cfg = WAN_CONFIGS[args.task]
    logging.info("Generation job args: %s", args)
    logging.info("Generation model config: %s", cfg)

    if dist.is_initialized():
        # Make every rank sample with rank 0's seed.
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    # NOTE(review): use_usp=True is passed below when ulysses_size/ring_size > 1,
    # but this function does not set up any sequence-parallel process groups —
    # confirm they are initialized elsewhere before enabling context parallelism.
    logging.info("Creating WanT2V pipeline.")
    wan_t2v = wan.WanT2V(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        device_id=device_id,
        rank=rank,
        t5_fsdp=args.t5_fsdp,
        dit_fsdp=args.dit_fsdp,
        use_usp=(args.ulysses_size > 1 or args.ring_size > 1),
        t5_cpu=args.t5_cpu,
    )
    return wan_t2v


def safe_stem(text: str, maxlen: int = 60) -> str:
    """Turn a prompt into a filesystem-safe filename fragment.

    Whitespace runs become ``_`` and every character outside ``[A-Za-z0-9_-]``
    is dropped, so non-ASCII prompts collapse to little or nothing. The result
    is cut to ``maxlen`` characters and stripped of leading/trailing
    underscores; ``"image"`` is returned when nothing remains.
    """
    text = re.sub(r"\s+", "_", text.strip())
    text = re.sub(r"[^A-Za-z0-9_\-]+", "", text)
    # Strip *before* falling back: a prompt that reduces to underscores only
    # (e.g. "一个 朴素" -> "_") must not yield an empty stem.
    return text[:maxlen].strip("_") or "image"


def generate_one(pipe, cfg, out_dir: Path, index: int):
    """Generate one video for a single request and save it as an .mp4 file.

    Generation settings other than the prompt come from the module-level
    ``args``.

    Args:
        pipe: pipeline returned by build_pipeline(); its ``generate`` is called.
        cfg: request dict. Supported keys:
            - ``"prompt"`` (required): the text prompt.
        out_dir: directory to write into (``str`` or ``Path``); created if missing.
        index: request index, used as the zero-padded filename prefix.

    Returns:
        ``(out_path, elapsed, detail)``: the output file path, the seconds spent
        in ``pipe.generate``, and a dict summarizing the request.

    Raises:
        KeyError: if ``cfg`` has no ``"prompt"``.
    """
    prompt = cfg["prompt"]
    out_dir = Path(out_dir)
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    filename = f"{index:03d}_{safe_stem(prompt)}_{stamp}.mp4"
    out_path = out_dir / filename

    # perf_counter is the monotonic clock intended for measuring durations.
    start = time.perf_counter()
    video = pipe.generate(
        prompt,
        size=SIZE_CONFIGS[args.size],
        frame_num=args.frame_num,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model)
    elapsed = time.perf_counter() - start

    # Every rank must run generate() in a distributed job, but only rank 0
    # writes the file; otherwise all ranks would race on the same output.
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank == 0:
        out_dir.mkdir(parents=True, exist_ok=True)
        wan_cfg = WAN_CONFIGS[args.task]
        cache_video(
            tensor=video[None],  # add a leading batch dimension
            save_file=out_path,
            fps=wan_cfg.sample_fps,
            nrow=1,
            normalize=True,
            value_range=(-1, 1))

    detail = {
        "index": index,
        "filename": filename,
        "elapsed_seconds": round(elapsed, 6),
        "prompt": prompt,
    }
    return out_path, elapsed, detail