support patch ascend

This commit is contained in:
2025-09-05 12:03:13 +08:00
parent ea0db79ebe
commit db19b64849
7 changed files with 77 additions and 29 deletions

18
main.py
View File

@@ -8,21 +8,11 @@ import re
import time
from datetime import datetime
from pathlib import Path
import torch
from functools import wraps
_orig_load = torch.load
@wraps(_orig_load)
def _load_patch(*args, **kwargs):
kwargs.setdefault("weights_only", False)
return _orig_load(*args, **kwargs)
torch.load = _load_patch
import patch
from modelscope.pipelines import pipeline
from modelscope.outputs import OutputKeys
import torch
def safe_stem(text: str, maxlen: int = 60) -> str:
@@ -58,7 +48,7 @@ def load_prompts(json_path: Path):
def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16):
pipe = pipeline('text-to-video-synthesis', model_path)
pipe = pipeline('text-to-video-synthesis', model_path, device=device)
return pipe
@@ -95,7 +85,7 @@ def main():
parser.add_argument("--json", required=True, help="测试文本 JSON 文件路径")
parser.add_argument("--results", required=True, help="结果 JSON 文件输出路径(*.json")
parser.add_argument("--outdir", required=True, help="图片输出目录")
parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="推理设备")
parser.add_argument("--device", default="cuda", help="推理设备")
parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"], help="推理精度")
args, _ = parser.parse_known_args()