Files
enginex-mlu370-text2video/main.py

155 lines
5.5 KiB
Python
Raw Normal View History

2025-09-02 16:54:16 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import re
import time
from datetime import datetime
from pathlib import Path
2025-09-05 12:03:13 +08:00
import patch
2025-09-03 10:51:05 +08:00
2025-09-02 16:54:16 +08:00
from modelscope.pipelines import pipeline
from modelscope.outputs import OutputKeys
2025-09-05 12:03:13 +08:00
import torch
2025-09-08 16:32:50 +08:00
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video
2025-09-02 16:54:16 +08:00
def safe_stem(text: str, maxlen: int = 60) -> str:
    """Turn a prompt into a filesystem-safe filename fragment.

    Whitespace runs become single underscores, every character outside
    ``[A-Za-z0-9_-]`` is dropped, and the result is truncated to *maxlen*.

    Args:
        text: Arbitrary prompt text.
        maxlen: Maximum length of the returned fragment.

    Returns:
        A non-empty fragment; falls back to ``"image"`` when nothing
        usable survives sanitisation.
    """
    text = re.sub(r"\s+", "_", text.strip())
    text = re.sub(r"[^A-Za-z0-9_\-]+", "", text)
    # Strip edge underscores *before* applying the fallback: the original
    # `(text[:maxlen] or "image").strip("_")` returned "" for inputs that
    # sanitise to pure underscores (e.g. "___"), yielding broken filenames.
    stem = text[:maxlen].strip("_")
    return stem or "image"
def load_prompts(json_path: Path):
    """Load generation prompts from a JSON file.

    Two top-level shapes are accepted (the previous docstring listed only
    the first):
      1) ``["prompt 1", "prompt 2", ...]`` — each string is wrapped as
         ``{"prompt": s}``.
      2) ``[{"prompt": "...", ...}, ...]`` — objects are passed through
         unchanged; each must contain a ``"prompt"`` key.

    Args:
        json_path: Path to a UTF-8 encoded JSON file.

    Returns:
        list[dict]: One config dict per prompt, each with a ``"prompt"`` key.

    Raises:
        ValueError: If the top level is not a list, if the list mixes
            strings and objects, or if an object lacks ``"prompt"``.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, list):
        raise ValueError("JSON 顶层必须是列表。")
    if all(isinstance(x, str) for x in data):
        return [{"prompt": s} for s in data]
    if all(isinstance(x, dict) for x in data):
        for obj in data:
            if "prompt" not in obj:
                raise ValueError("每个对象都需要包含 'prompt' 字段")
        return list(data)
    raise ValueError("JSON 列表元素需全为字符串或全为对象。")
2025-09-08 16:32:50 +08:00
def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16, model_type: str = "text-to-video-synthesis"):
    """Construct a text-to-video pipeline for the requested backend.

    Args:
        model_path: Local model directory or hub model id.
        device: Inference device string (used by the modelscope backend).
        dtype: Torch dtype for the diffusers backend.
        model_type: ``"text-to-video-synthesis"`` (modelscope) or
            ``"text-to-video-ms"`` (diffusers).

    Returns:
        The constructed pipeline object.

    Raises:
        ValueError: For an unrecognised ``model_type``.
    """
    if model_type == "text-to-video-ms":
        pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)
        # Offload idle submodules and slice VAE decoding to reduce VRAM use.
        pipe.enable_model_cpu_offload()
        pipe.enable_vae_slicing()
        return pipe
    if model_type == "text-to-video-synthesis":
        return pipeline('text-to-video-synthesis', model_path, device=device)
    raise ValueError(f"不支持的模型类型: {model_type}")
2025-09-08 16:32:50 +08:00
def generate_one(pipe, cfg: dict, out_dir: Path, index: int, model_type: str = "text-to-video-synthesis"):
    """Generate one video from *cfg* and return ``(saved path, elapsed seconds, detail dict)``.

    Supported cfg fields:
        - prompt (required)
    """
    prompt = cfg["prompt"]
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    filename = f"{index:03d}_{safe_stem(prompt)}_{stamp}.mp4"
    out_path = out_dir / filename

    start = time.time()
    if model_type == "text-to-video-synthesis":
        # The modelscope pipeline writes the file itself and reports the path.
        output_video_path = pipe({"text": prompt}, output_video=str(out_path))[OutputKeys.OUTPUT_VIDEO]
    elif model_type == "text-to-video-ms":
        # The diffusers pipeline returns raw frames; serialise them to mp4 here.
        frames = pipe(prompt, num_frames=16).frames[0]
        export_to_video(frames, str(out_path))
    else:
        raise ValueError(f"不支持的模型类型: {model_type}")
    elapsed = time.time() - start

    detail = {
        "index": index,
        "filename": filename,
        "elapsed_seconds": round(elapsed, 6),
        "prompt": prompt,
    }
    return out_path, elapsed, detail
def main():
    """CLI entry point: batch text-to-video generation with a JSON report.

    Reads prompts from ``--json``, generates one video per prompt into
    ``--outdir``, and writes per-case timing records plus aggregate
    latency to ``--results``.
    """
    parser = argparse.ArgumentParser(
        # Fixed garbled description: the opening fullwidth paren was missing.
        description="Stable Diffusion 基准与批量生成脚本(JSON 结果)"
    )
    parser.add_argument("--model", required=True, help="模型路径或模型名(本地目录或 HF 仓库名)")
    parser.add_argument("--json", required=True, help="测试文本 JSON 文件路径")
    # Fixed unbalanced paren in the help string ("(*.json" -> "(*.json)").
    parser.add_argument("--results", required=True, help="结果 JSON 文件输出路径(*.json)")
    parser.add_argument("--outdir", required=True, help="图片输出目录")
    parser.add_argument("--device", default="cuda", help="推理设备")
    parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"], help="推理精度")
    parser.add_argument("--model_type", default="text-to-video-synthesis",
                        choices=["text-to-video-synthesis", "text-to-video-ms"], help="模型类型")
    # parse_known_args: tolerate extra flags injected by launchers.
    args, _ = parser.parse_known_args()

    model_path = args.model
    json_path = Path(args.json)
    results_path = Path(args.results)
    out_dir = Path(args.outdir)
    out_dir.mkdir(parents=True, exist_ok=True)
    results_path.parent.mkdir(parents=True, exist_ok=True)
    dtype = torch.float16 if args.dtype == "fp16" else torch.float32

    prompts = load_prompts(json_path)
    if not prompts:
        raise ValueError("测试列表为空。")

    pipe = build_pipeline(model_path=model_path, device=args.device, dtype=dtype, model_type=args.model_type)

    records = []
    total_start = time.time()
    for i, cfg in enumerate(prompts, 1):
        out_path, elapsed, detail = generate_one(pipe, cfg, out_dir, i, model_type=args.model_type)
        print(f"[{i}/{len(prompts)}] saved: {out_path.name} elapsed: {elapsed:.3f}s")
        records.append(detail)

    total_elapsed = round(time.time() - total_start, 6)
    avg_latency = total_elapsed / len(records) if records else 0

    # Result JSON structure.
    result_obj = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "model": model_path,
        # modelscope pipelines may not expose .device — fall back gracefully.
        "device": str(getattr(pipe, "device", "unknown")),
        "dtype": "fp16" if dtype == torch.float16 else "fp32",
        "count": len(records),
        "total_elapsed_seconds": total_elapsed,
        "avg_latency": avg_latency,
        "cases": records,
    }
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(result_obj, f, ensure_ascii=False, indent=2)

    # Typo fix in the summary line: "vidoes" -> "videos".
    print(f"\nAll done. videos: {len(records)}, total_elapsed: {total_elapsed:.3f}s, avg_latency: {avg_latency:.3f}")
    print(f"Results JSON: {results_path}")
    print(f"Output dir : {out_dir.resolve()}")


if __name__ == "__main__":
    main()