init
This commit is contained in:
26
torch_vacc/vacc/amp/__init__.py
Normal file
26
torch_vacc/vacc/amp/__init__.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from typing import List
|
||||
import torch
|
||||
from torch_vacc._vacc_libs import _torch_vacc
|
||||
|
||||
from .grad_scaler import OptState, GradScaler
|
||||
from .autocast_mode import autocast, custom_fwd, custom_bwd
|
||||
|
||||
|
||||
def get_amp_supported_dtype() -> List[torch.dtype]:
    """List the floating-point dtypes usable with AMP on this backend.

    Returns:
        A new list containing ``torch.float16`` and ``torch.bfloat16``.
    """
    half_precision = torch.float16
    brain_float = torch.bfloat16
    return [half_precision, brain_float]
|
||||
|
||||
|
||||
def is_autocast_enabled() -> bool:
    """Report whether autocast is currently enabled for this backend.

    Thin pass-through to the native ``_torch_vacc`` binding; no state is
    held at the Python level.
    """
    enabled = _torch_vacc.is_autocast_enabled()
    return enabled
|
||||
|
||||
|
||||
def set_autocast_enabled(enable: bool):
    """Switch autocast on or off for this backend.

    Args:
        enable: ``True`` to enable autocast, ``False`` to disable it.

    The call is forwarded to the native ``_torch_vacc`` binding; its
    return value (if any) is deliberately discarded.
    """
    _torch_vacc.set_autocast_enabled(enable)
|
||||
|
||||
|
||||
def get_autocast_dtype() -> torch.dtype:
    """Return the dtype autocast is currently configured to use.

    Pass-through to the native ``_torch_vacc`` binding.
    """
    current = _torch_vacc.get_autocast_dtype()
    return current
|
||||
|
||||
|
||||
def set_autocast_dtype(dtype: torch.dtype):
    """Configure the dtype autocast should use.

    Args:
        dtype: target dtype (e.g. ``torch.float16`` or ``torch.bfloat16``).

    Returns:
        Whatever the native ``_torch_vacc`` binding returns — forwarded
        unchanged (note this differs from ``set_autocast_enabled``, which
        discards the binding's result).
    """
    result = _torch_vacc.set_autocast_dtype(dtype)
    return result
|
||||
Reference in New Issue
Block a user