[tools] add fp8 max/min constant in utils (#3959)
This commit is contained in:
committed by
GitHub
parent
ccdd10c84b
commit
18c27131f5
@@ -14,6 +14,7 @@
|
||||
"""Common utilities."""
|
||||
|
||||
import base64
|
||||
import builtins
|
||||
import ctypes
|
||||
import dataclasses
|
||||
import io
|
||||
@@ -72,12 +73,25 @@ logger = logging.getLogger(__name__)
|
||||
show_time_cost = False
|
||||
time_infos = {}
|
||||
|
||||
HIP_FP8_E4M3_FNUZ_MAX = 224.0
|
||||
|
||||
|
||||
# https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip
|
||||
def is_hip() -> bool:
|
||||
return torch.version.hip is not None
|
||||
|
||||
|
||||
if is_hip():
|
||||
FP8_E4M3_MAX = HIP_FP8_E4M3_FNUZ_MAX
|
||||
else:
|
||||
FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
FP8_E4M3_MIN = -FP8_E4M3_MAX
|
||||
|
||||
builtins.FP8_E4M3_MAX = FP8_E4M3_MAX
|
||||
builtins.FP8_E4M3_MIN = FP8_E4M3_MIN
|
||||
|
||||
|
||||
def is_rocm() -> bool:
|
||||
return torch.cuda.is_available() and torch.version.hip
|
||||
|
||||
|
||||
Reference in New Issue
Block a user