From 63e845d0bb3a4095af9640242aaca4ed8656fed8 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Sat, 28 Sep 2024 12:27:54 -0700 Subject: [PATCH] Add float8 dynamic quant to torchao_utils (#1528) --- python/sglang/srt/layers/torchao_utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py index 51a307e5c..46b082401 100644 --- a/python/sglang/srt/layers/torchao_utils.py +++ b/python/sglang/srt/layers/torchao_utils.py @@ -18,11 +18,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str): """ # Lazy import to suppress some warnings from torchao.quantization import ( + float8_dynamic_activation_float8_weight, int4_weight_only, int8_dynamic_activation_int8_weight, int8_weight_only, quantize_, ) + from torchao.quantization.observer import PerRow, PerTensor dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False) dummy_linear.weight = param @@ -45,6 +47,22 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str): # this requires newer hardware # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 quantize_(dummy_linear, float8_weight_only()) + elif "fp8dq" in torchao_config: + granularity = torchao_config.split("-")[-1] + GRANULARITY_MAP = { + "per_row": PerRow(), + "per_tensor": PerTensor(), + } + assert ( + granularity in GRANULARITY_MAP + ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}" + quantize_( + dummy_linear, + float8_dynamic_activation_float8_weight( + granularity=GRANULARITY_MAP[granularity] + ), + ) + return dummy_linear.weight