support diffusers ms models

This commit is contained in:
2025-09-08 16:32:50 +08:00
parent 0ae0bd7904
commit 495c3fcd8a
6 changed files with 47 additions and 7 deletions

28
main.py
View File

@@ -13,6 +13,8 @@ import patch
from modelscope.pipelines import pipeline
from modelscope.outputs import OutputKeys
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video
def safe_stem(text: str, maxlen: int = 60) -> str:
@@ -47,12 +49,19 @@ def load_prompts(json_path: Path):
return prompts
def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16):
pipe = pipeline('text-to-video-synthesis', model_path, device=device)
def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16, model_type: str = "text-to-video-synthesis"):
    """Construct a text-to-video pipeline for the selected backend.

    Args:
        model_path: Model id or local path handed to the backend loader.
        device: Inference device for the ModelScope branch; the diffusers
            branch does not use it (it relies on CPU offload instead).
        dtype: torch dtype for the diffusers branch (ignored by the
            ModelScope branch, which takes no dtype here).
        model_type: Either "text-to-video-synthesis" (ModelScope pipeline)
            or "text-to-video-ms" (diffusers DiffusionPipeline).

    Returns:
        The constructed pipeline object (ModelScope pipeline or
        diffusers DiffusionPipeline, depending on model_type).

    Raises:
        ValueError: If model_type is not one of the supported values.
    """
    if model_type == "text-to-video-synthesis":
        pipe = pipeline('text-to-video-synthesis', model_path, device=device)
    elif model_type == "text-to-video-ms":
        pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)
        pipe.enable_model_cpu_offload()  # save VRAM
        pipe.enable_vae_slicing()  # presumably trades speed for lower VAE decode memory — see diffusers docs
    else:
        raise ValueError(f"不支持的模型类型: {model_type}")
    return pipe
def generate_one(pipe, cfg: dict, out_dir: Path, index: int):
def generate_one(pipe, cfg: dict, out_dir: Path, index: int, model_type: str = "text-to-video-synthesis"):
"""
依据 cfg 生成一张图并返回 (保存路径, 耗时秒, 详细参数)
支持字段:
@@ -65,7 +74,13 @@ def generate_one(pipe, cfg: dict, out_dir: Path, index: int):
out_path = out_dir / filename
start = time.time()
output_video_path = pipe({"text": prompt}, output_video=str(out_path))[OutputKeys.OUTPUT_VIDEO]
if model_type == "text-to-video-synthesis":
output_video_path = pipe({"text": prompt}, output_video=str(out_path))[OutputKeys.OUTPUT_VIDEO]
elif model_type == "text-to-video-ms":
frames = pipe(prompt, num_frames=16).frames[0]
export_to_video(frames, str(out_path))
else:
raise ValueError(f"不支持的模型类型: {model_type}")
elapsed = time.time() - start
detail = {
@@ -87,6 +102,7 @@ def main():
parser.add_argument("--outdir", required=True, help="图片输出目录")
parser.add_argument("--device", default="cuda", help="推理设备")
parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"], help="推理精度")
parser.add_argument("--model_type", default="text-to-video-synthesis", choices=["text-to-video-synthesis", "text-to-video-ms"], help="模型类型")
args, _ = parser.parse_known_args()
model_path = args.model
@@ -103,12 +119,12 @@ def main():
if not prompts:
raise ValueError("测试列表为空。")
pipe = build_pipeline(model_path=model_path, device=args.device, dtype=dtype)
pipe = build_pipeline(model_path=model_path, device=args.device, dtype=dtype, model_type=args.model_type)
records = []
total_start = time.time()
for i, cfg in enumerate(prompts, 1):
out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i)
out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i, model_type=args.model_type)
print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
records.append(detail)
total_elapsed = round(time.time() - total_start, 6)