"""Automatic mixed precision (AMP) helpers backed by the ``_torch_vacc`` extension."""

from typing import List

import torch
from torch_vacc._vacc_libs import _torch_vacc

# Imported here so they are available from this module's namespace.
from .grad_scaler import OptState, GradScaler
from .autocast_mode import autocast, custom_fwd, custom_bwd


def get_amp_supported_dtype() -> List[torch.dtype]:
    """Return the dtypes this module lists as usable for AMP.

    Returns:
        A newly built list holding ``torch.float16`` and ``torch.bfloat16``,
        in that order.
    """
    supported = [torch.float16, torch.bfloat16]
    return supported


def is_autocast_enabled() -> bool:
    """Return the autocast-enabled flag as reported by ``_torch_vacc``."""
    enabled = _torch_vacc.is_autocast_enabled()
    return enabled


def set_autocast_enabled(enable: bool) -> None:
    """Forward ``enable`` to ``_torch_vacc.set_autocast_enabled``.

    Args:
        enable: Value handed to the backend unchanged.
    """
    _torch_vacc.set_autocast_enabled(enable)


def get_autocast_dtype() -> torch.dtype:
    """Return the autocast dtype as reported by ``_torch_vacc``."""
    current = _torch_vacc.get_autocast_dtype()
    return current


def set_autocast_dtype(dtype: torch.dtype):
    """Forward ``dtype`` to ``_torch_vacc.set_autocast_dtype``.

    Args:
        dtype: Value handed to the backend unchanged.

    Returns:
        Whatever the backend call returns.
    """
    # NOTE(review): unlike set_autocast_enabled, the backend result is
    # propagated to the caller; presumably it is None -- confirm.
    outcome = _torch_vacc.set_autocast_dtype(dtype)
    return outcome