Adding SGLang FP8 Utils (#2348)
This commit is contained in:
27
python/sglang/srt/layers/quantization/fp8_utils.py
Normal file
27
python/sglang/srt/layers/quantization/fp8_utils.py
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_e4m3fn_to_e4m3fnuz(
|
||||||
|
weight: torch.Tensor,
|
||||||
|
weight_scale: torch.Tensor,
|
||||||
|
input_scale: Optional[torch.Tensor] = None,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||||
|
assert weight.dtype == torch.float8_e4m3fn
|
||||||
|
# The bits pattern 10000000(-128) represents zero in e4m3fn
|
||||||
|
# but NaN in e4m3fnuz. So here we set it to 0.
|
||||||
|
# https://onnx.ai/onnx/technical/float8.html
|
||||||
|
weight_as_int8 = weight.view(torch.int8)
|
||||||
|
ROCM_FP8_NAN_AS_INT = -128
|
||||||
|
weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
|
||||||
|
weight = weight_as_int8.view(torch.float8_e4m3fnuz)
|
||||||
|
|
||||||
|
# For the same bits representation, e4m3fnuz value is half of
|
||||||
|
# the e4m3fn value, so we should double the scaling factor to
|
||||||
|
# get the same dequantized value.
|
||||||
|
# https://onnx.ai/onnx/technical/float8.html
|
||||||
|
weight_scale = weight_scale * 2.0
|
||||||
|
if input_scale is not None:
|
||||||
|
input_scale = input_scale * 2.0
|
||||||
|
return weight, weight_scale, input_scale
|
||||||
Reference in New Issue
Block a user