This commit is contained in:
2026-04-02 04:53:13 +00:00
parent 80932c96e5
commit 24df76db9d
1987 changed files with 447445 additions and 0 deletions

View File

@@ -0,0 +1,26 @@
from typing import List
import torch
from torch_vacc._vacc_libs import _torch_vacc
from .grad_scaler import OptState, GradScaler
from .autocast_mode import autocast, custom_fwd, custom_bwd
def get_amp_supported_dtype() -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]
def is_autocast_enabled() -> bool:
return _torch_vacc.is_autocast_enabled()
def set_autocast_enabled(enable: bool):
_torch_vacc.set_autocast_enabled(enable)
def get_autocast_dtype() -> torch.dtype:
return _torch_vacc.get_autocast_dtype()
def set_autocast_dtype(dtype: torch.dtype):
return _torch_vacc.set_autocast_dtype(dtype)