diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index ee4ac8685..54909ac9d 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -14,6 +14,7 @@ """Common utilities.""" import base64 +import builtins import ctypes import dataclasses import io @@ -72,12 +73,25 @@ logger = logging.getLogger(__name__) show_time_cost = False time_infos = {} +HIP_FP8_E4M3_FNUZ_MAX = 224.0 + # https://pytorch.org/docs/stable/notes/hip.html#checking-for-hip def is_hip() -> bool: return torch.version.hip is not None +if is_hip(): + FP8_E4M3_MAX = HIP_FP8_E4M3_FNUZ_MAX +else: + FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max + +FP8_E4M3_MIN = -FP8_E4M3_MAX + +builtins.FP8_E4M3_MAX = FP8_E4M3_MAX +builtins.FP8_E4M3_MIN = FP8_E4M3_MIN + + def is_rocm() -> bool: return torch.cuda.is_available() and torch.version.hip