Files
enginex-vastai-va16-vllm/torch_vacc/vacc/amp/__init__.py
2026-04-02 04:55:00 +00:00

27 lines
638 B
Python

from typing import List
import torch
from torch_vacc._vacc_libs import _torch_vacc
from .grad_scaler import OptState, GradScaler
from .autocast_mode import autocast, custom_fwd, custom_bwd
def get_amp_supported_dtype() -> List[torch.dtype]:
    """Return the reduced-precision dtypes supported for AMP on this backend.

    The result is ``[torch.float16, torch.bfloat16]``. A new list is built on
    every call, so callers may mutate it without affecting later calls.
    """
    supported = (torch.float16, torch.bfloat16)
    return list(supported)
def is_autocast_enabled() -> bool:
    """Report whether autocast is currently enabled.

    The answer is obtained by querying the native ``_torch_vacc`` extension
    and returned unchanged.
    """
    enabled = _torch_vacc.is_autocast_enabled()
    return enabled
def set_autocast_enabled(enable: bool):
    """Turn autocast on or off through the native ``_torch_vacc`` extension.

    Args:
        enable: ``True`` to enable autocast, ``False`` to disable it.

    Whatever the extension call returns is discarded; this function always
    returns ``None``.
    """
    _torch_vacc.set_autocast_enabled(enable)
    return None
def get_autocast_dtype() -> torch.dtype:
    """Return the autocast dtype currently configured in ``_torch_vacc``.

    The value reported by the native extension is passed back unchanged.
    """
    current = _torch_vacc.get_autocast_dtype()
    return current
def set_autocast_dtype(dtype: torch.dtype):
    """Set the autocast dtype in the native ``_torch_vacc`` extension.

    Args:
        dtype: the dtype autocast should use. Presumably one of
            ``get_amp_supported_dtype()``; no validation happens here, so any
            checking is left to the extension (NOTE(review): confirm).

    Returns:
        Whatever the extension's setter returns, forwarded unchanged.
        (``set_autocast_enabled`` discards its result instead; the two are
        kept as-is to avoid changing caller-visible behavior.)
    """
    outcome = _torch_vacc.set_autocast_dtype(dtype)
    return outcome