From e483ab6d20740c81d13c89b0a5282d632ee2a1a4 Mon Sep 17 00:00:00 2001 From: Enrique Shockwave <33002121+qeternity@users.noreply.github.com> Date: Tue, 19 Aug 2025 02:53:15 +0100 Subject: [PATCH] enable marlin fp8 blockwise (#8990) --- python/sglang/srt/layers/quantization/fp8.py | 171 +++++++++--------- .../srt/layers/quantization/fp8_utils.py | 9 + 2 files changed, 94 insertions(+), 86 deletions(-) diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index f2e07b515..5c40bd1f0 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -49,6 +49,7 @@ from sglang.srt.layers.quantization.fp8_kernel import ( ) from sglang.srt.layers.quantization.fp8_utils import ( apply_fp8_linear, + can_auto_enable_marlin_fp8, cutlass_fp8_supported, dispatch_w8a8_block_fp8_linear, input_to_float8, @@ -209,17 +210,13 @@ class Fp8LinearMethod(LinearMethodBase): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization - self.use_marlin = ( - get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE - ) - # Disable marlin for ROCm - if _is_hip: - self.use_marlin = False + self.use_marlin = False + if _is_cuda and MARLIN_FP8_AVAILABLE: + force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") + auto_enable = can_auto_enable_marlin_fp8() + self.use_marlin = force_marlin or auto_enable self.block_quant = self.quant_config.weight_block_size is not None - if self.block_quant: - # Marlin doesn't support block-wise fp8 - self.use_marlin = False self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear() @@ -332,7 +329,6 @@ class Fp8LinearMethod(LinearMethodBase): layer.register_parameter("input_scale", None) def process_weights_after_loading(self, layer: Module) -> None: - # Block quant doesn't need to process weights after loading if self.block_quant: # If ROCm, normalize the weights and scales to e4m3fnuz if _is_fp8_fnuz: @@ -342,7 +338,6 @@ class Fp8LinearMethod(LinearMethodBase): weight_scale=layer.weight_scale_inv, input_scale=None, ) - layer.input_scale = None elif _is_cpu: assert ( @@ -352,90 +347,94 @@ class Fp8LinearMethod(LinearMethodBase): return else: weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data - layer.weight = torch.nn.Parameter(weight, requires_grad=False) - layer.weight_scale_inv = torch.nn.Parameter( - weight_scale, requires_grad=False - ) - return - - layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) - - # If checkpoint not serialized fp8, quantize the weights. - if not self.quant_config.is_checkpoint_fp8_serialized: - if self.cutlass_fp8_supported or self.use_marlin: - # apply per-channel quantization default, as cutlass sgl-kernel and marlin only support per-channel scale - qweight, weight_scale = per_token_group_quant_fp8( - layer.weight, layer.weight.shape[-1] - ) - weight_scale = weight_scale.t().contiguous() - else: - # per-tensor quantization - qweight, weight_scale = input_to_float8(layer.weight) - - # Update the layer with the new values. - layer.weight = Parameter(qweight.t(), requires_grad=False) - layer.weight_scale = Parameter(weight_scale, requires_grad=False) - layer.input_scale = None - - # If checkpoint is fp8, handle that there are N scales for N - # shards in a fused module + layer.weight = Parameter(weight, requires_grad=False) + layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False) else: - layer.weight_scale = torch.nn.Parameter( - layer.weight_scale.data, requires_grad=False - ) - if ( - hasattr(self.quant_config, "activation_scheme") - and self.quant_config.activation_scheme == "static" - ) or ( - hasattr(self.quant_config, "linear_activation_scheme") - and self.quant_config.linear_activation_scheme == "static" - ): - layer.input_scale = torch.nn.Parameter( - layer.input_scale.data, requires_grad=False - ) + layer.weight = Parameter(layer.weight.data, requires_grad=False) - # cutlass sgl-kernel and marlin only support per-channel scale - if self.cutlass_fp8_supported or self.use_marlin: - weight = layer.weight - weight_scale = convert_to_channelwise( - layer.weight_scale, layer.logical_widths - ) + # If checkpoint not serialized fp8, quantize the weights. + if not self.quant_config.is_checkpoint_fp8_serialized: + if self.cutlass_fp8_supported or self.use_marlin: + # apply per-channel quantization default as + # cutlass sgl-kernel and marlin only support per-channel scale + qweight, weight_scale = per_token_group_quant_fp8( + layer.weight, layer.weight.shape[-1] + ) + weight_scale = weight_scale.t().contiguous() + else: + # per-tensor quantization + qweight, weight_scale = input_to_float8(layer.weight) + + # Update the layer with the new values. + layer.weight = Parameter(qweight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + layer.input_scale = None + + # If checkpoint is fp8, handle that there are N scales for N + # shards in a fused module else: - # Dequant -> Quant with max scale so we can run per tensor. - weight = layer.weight - weight_scale = layer.weight_scale - # If ROCm, normalize the weights and scales to e4m3fnuz - if _is_fp8_fnuz: - weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + layer.weight_scale = Parameter( + layer.weight_scale.data, requires_grad=False + ) + if ( + hasattr(self.quant_config, "activation_scheme") + and self.quant_config.activation_scheme == "static" + ) or ( + hasattr(self.quant_config, "linear_activation_scheme") + and self.quant_config.linear_activation_scheme == "static" + ): + layer.input_scale = Parameter( + layer.input_scale.data, requires_grad=False + ) + + # cutlass sgl-kernel and marlin only support per-channel scale + if self.cutlass_fp8_supported or self.use_marlin: + weight = layer.weight + weight_scale = convert_to_channelwise( + layer.weight_scale, layer.logical_widths + ) + else: + # Dequant -> Quant with max scale so we can run per tensor. + weight = layer.weight + weight_scale = layer.weight_scale + # If ROCm, normalize the weights and scales to e4m3fnuz + if _is_fp8_fnuz: + weight, weight_scale, input_scale = ( + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=weight_scale, + input_scale=layer.input_scale, + ) + ) + if input_scale is not None: + layer.input_scale = Parameter( + input_scale, requires_grad=False + ) + + weight_scale, weight = requantize_with_max_scale( weight=weight, weight_scale=weight_scale, - input_scale=layer.input_scale, + logical_widths=layer.logical_widths, ) - if input_scale is not None: - layer.input_scale = Parameter(input_scale, requires_grad=False) - weight_scale, weight = requantize_with_max_scale( - weight=weight, - weight_scale=weight_scale, - logical_widths=layer.logical_widths, - ) - - # Update layer with new values. - layer.weight = Parameter(weight.t(), requires_grad=False) - layer.weight_scale = Parameter(weight_scale, requires_grad=False) - if ( - hasattr(self.quant_config, "activation_scheme") - and self.quant_config.activation_scheme == "static" - ) or ( - hasattr(self.quant_config, "linear_activation_scheme") - and self.quant_config.linear_activation_scheme == "static" - ): - layer.input_scale = Parameter( - layer.input_scale.max(), requires_grad=False - ) + # Update layer with new values. + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + if ( + hasattr(self.quant_config, "activation_scheme") + and self.quant_config.activation_scheme == "static" + ) or ( + hasattr(self.quant_config, "linear_activation_scheme") + and self.quant_config.linear_activation_scheme == "static" + ): + layer.input_scale = Parameter( + layer.input_scale.max(), requires_grad=False + ) if self.use_marlin: - prepare_fp8_layer_for_marlin(layer) + if self.block_quant: + layer.weight_block_size = self.quant_config.weight_block_size + prepare_fp8_layer_for_marlin(layer, not self.block_quant) # Activations not quantized for marlin. del layer.input_scale diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 259d0098b..f051bd733 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -789,3 +789,12 @@ def apply_fp8_linear( bias, input.dtype, ) + + +def can_auto_enable_marlin_fp8() -> bool: + try: + major, minor = get_device_capability() + sm = major * 10 + minor + return 80 <= sm < 89 + except Exception: + return False