fix ms model
This commit is contained in:
5
main.py
5
main.py
@@ -54,8 +54,9 @@ def build_pipeline(model_path: str, device: str = "cuda", dtype=torch.float16, m
|
||||
pipe = pipeline('text-to-video-synthesis', model_path, device=device)
|
||||
elif model_type == "text-to-video-ms":
|
||||
pipe = DiffusionPipeline.from_pretrained(model_path, torch_dtype=dtype)
|
||||
pipe.enable_model_cpu_offload() # 省显存
|
||||
pipe.enable_vae_slicing()
|
||||
pipe.to(device)
|
||||
# pipe.enable_model_cpu_offload() # 省显存
|
||||
# pipe.enable_vae_slicing()
|
||||
else:
|
||||
raise ValueError(f"不支持的模型类型: {model_type}")
|
||||
return pipe
|
||||
|
||||
Reference in New Issue
Block a user