diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py
index 51a307e5c..46b082401 100644
--- a/python/sglang/srt/layers/torchao_utils.py
+++ b/python/sglang/srt/layers/torchao_utils.py
@@ -18,11 +18,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
     """
     # Lazy import to suppress some warnings
     from torchao.quantization import (
+        float8_dynamic_activation_float8_weight,
         int4_weight_only,
         int8_dynamic_activation_int8_weight,
         int8_weight_only,
         quantize_,
     )
+    from torchao.quantization.observer import PerRow, PerTensor
 
     dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
     dummy_linear.weight = param
@@ -45,6 +47,22 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
         quantize_(dummy_linear, float8_weight_only())
+    elif "fp8dq" in torchao_config:
+        granularity = torchao_config.split("-")[-1]
+        GRANULARITY_MAP = {
+            "per_row": PerRow(),
+            "per_tensor": PerTensor(),
+        }
+        assert (
+            granularity in GRANULARITY_MAP
+        ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
+        quantize_(
+            dummy_linear,
+            float8_dynamic_activation_float8_weight(
+                granularity=GRANULARITY_MAP[granularity]
+            ),
+        )
+
     return dummy_linear.weight
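
For reference, a minimal usage sketch of the new "fp8dq" branch (not part of the diff). It assumes torchao with float8 support is installed and the GPU has CUDA compute capability >= 8.9 (fp8 kernels are not available on older architectures); the 4096x4096 bfloat16 weight is only an illustrative placeholder.

# Usage sketch, not part of the diff.
import torch
from sglang.srt.layers.torchao_utils import torchao_quantize_param_data

# A weight laid out like a linear layer: (out_features, in_features).
weight = torch.nn.Parameter(
    torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda"),
    requires_grad=False,
)

# "fp8dq-per_row" selects dynamic fp8 activations + fp8 weights with per-row scales;
# "fp8dq-per_tensor" selects a single scale per tensor.
quantized_weight = torchao_quantize_param_data(weight, "fp8dq-per_row")
print(type(quantized_weight), quantized_weight.shape)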