support patch ascend
This commit is contained in:
18
main.py
18
main.py
@@ -8,21 +8,11 @@ import re
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from functools import wraps
|
||||
|
||||
_orig_load = torch.load
|
||||
|
||||
@wraps(_orig_load)
|
||||
def _load_patch(*args, **kwargs):
|
||||
kwargs.setdefault("weights_only", False)
|
||||
return _orig_load(*args, **kwargs)
|
||||
|
||||
torch.load = _load_patch
|
||||
import patch
|
||||
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.outputs import OutputKeys
|
||||
import torch
|
||||
|
||||
|
||||
def safe_stem(text: str, maxlen: int = 60) -> str:
|
||||
@@ -58,7 +48,7 @@ def load_prompts(json_path: Path):
|
||||
|
||||
|
||||
def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16):
|
||||
pipe = pipeline('text-to-video-synthesis', model_path)
|
||||
pipe = pipeline('text-to-video-synthesis', model_path, device=device)
|
||||
return pipe
|
||||
|
||||
|
||||
@@ -95,7 +85,7 @@ def main():
|
||||
parser.add_argument("--json", required=True, help="测试文本 JSON 文件路径")
|
||||
parser.add_argument("--results", required=True, help="结果 JSON 文件输出路径(*.json)")
|
||||
parser.add_argument("--outdir", required=True, help="图片输出目录")
|
||||
parser.add_argument("--device", default="cuda", choices=["cuda", "cpu"], help="推理设备")
|
||||
parser.add_argument("--device", default="cuda", help="推理设备")
|
||||
parser.add_argument("--dtype", default="fp16", choices=["fp16", "fp32"], help="推理精度")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user