27 lines
638 B
Python
27 lines
638 B
Python
from typing import List
|
|
import torch
|
|
from torch_vacc._vacc_libs import _torch_vacc
|
|
|
|
from .grad_scaler import OptState, GradScaler
|
|
from .autocast_mode import autocast, custom_fwd, custom_bwd
|
|
|
|
|
|
def get_amp_supported_dtype() -> List[torch.dtype]:
    """Return the dtypes this module lists as supported for AMP.

    Returns:
        A list containing ``torch.float16`` and ``torch.bfloat16``, in that
        order. A new list object is built on every call, so mutating the
        result does not affect later callers.
    """
    return [torch.float16, torch.bfloat16]
|
|
|
|
|
|
def is_autocast_enabled() -> bool:
    """Report whether autocast is currently enabled.

    This is a thin wrapper: the answer comes directly from
    ``_torch_vacc.is_autocast_enabled()`` and is returned unchanged.

    Returns:
        The value reported by the ``_torch_vacc`` extension module.
    """
    return _torch_vacc.is_autocast_enabled()
|
|
|
|
|
|
def set_autocast_enabled(enable: bool):
    """Forward the autocast on/off flag to the ``_torch_vacc`` extension.

    Args:
        enable: Flag passed unchanged to
            ``_torch_vacc.set_autocast_enabled``.

    The backend call's result is deliberately not returned, so this
    function always returns ``None``.
    """
    _torch_vacc.set_autocast_enabled(enable)
|
|
|
|
|
|
def get_autocast_dtype() -> torch.dtype:
    """Return the autocast dtype reported by the ``_torch_vacc`` extension.

    This is a thin wrapper around ``_torch_vacc.get_autocast_dtype()``; the
    value is returned unchanged.

    Returns:
        The dtype currently reported by the backend.
    """
    return _torch_vacc.get_autocast_dtype()
|
|
|
|
|
|
def set_autocast_dtype(dtype: torch.dtype):
    """Forward the requested autocast dtype to the ``_torch_vacc`` extension.

    Args:
        dtype: Dtype passed unchanged to ``_torch_vacc.set_autocast_dtype``.
            No validation is performed here.

    Returns:
        Whatever ``_torch_vacc.set_autocast_dtype`` returns. Unlike
        ``set_autocast_enabled``, the backend result is propagated to the
        caller.
    """
    return _torch_vacc.set_autocast_dtype(dtype)
|