diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index d504b5ac4..8dcde41e8 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -557,7 +557,10 @@ def apply_fp8_linear(
     # We also don't pad when using torch.compile,
     # as it breaks with dynamic shapes.
     if pad_output is None:
-        pad_output = not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+        pad_output = (
+            not get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE")
+            and not cutlass_fp8_supported
+        )
     output_padding = 17 if pad_output else None
 
     # View input as 2D matrix for fp8 methods
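The hunk above changes the default so that output padding is applied only when neither torch.compile nor the CUTLASS FP8 path is in play: the pad-to-17-rows trick only benefits torch._scaled_mm. A minimal standalone sketch of the resulting decision logic (the helper name is hypothetical, not part of the diff):

def should_pad_output(torch_compile_enabled: bool, cutlass_fp8_supported: bool) -> bool:
    # Padding the batch dimension to 17 helps torch._scaled_mm on small
    # batches, but the CUTLASS kernel does not need it, and torch.compile
    # breaks on the resulting dynamic shapes.
    return not torch_compile_enabled and not cutlass_fp8_supported
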
diff --git a/python/sglang/srt/layers/quantization/fpgemm_fp8.py b/python/sglang/srt/layers/quantization/fpgemm_fp8.py
new file mode 100644
index 000000000..fcfba7b09
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/fpgemm_fp8.py
@@ -0,0 +1,199 @@
+# SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
+
+import logging
+from typing import Any, Optional
+
+import torch
+from torch.nn import Module
+from torch.nn.parameter import Parameter
+
+from sglang.srt.layers.linear import LinearBase, LinearMethodBase
+from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.fp8_utils import (
+    apply_fp8_linear,
+    can_auto_enable_marlin_fp8,
+    cutlass_fp8_supported,
+    normalize_e4m3fn_to_e4m3fnuz,
+)
+from sglang.srt.layers.quantization.marlin_utils_fp8 import (
+    apply_fp8_marlin_linear,
+    prepare_fp8_layer_for_marlin,
+)
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
+from sglang.srt.layers.quantization.utils import is_layer_skipped
+from sglang.srt.utils import get_bool_env_var, is_cuda, is_fp8_fnuz
+
+_is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
+
+logger = logging.getLogger(__name__)
+
+
+class FBGEMMFp8Config(QuantizationConfig):
+    """Config class for FBGEMM Fp8."""
+
+    def __init__(self, ignore_list: list[str], input_scale_ub: float):
+        super().__init__()
+        self.ignore_list = ignore_list if ignore_list else []
+        self.input_scale_ub = input_scale_ub
+
+        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
+        # kernel for fast weight-only FP8 quantization.
+        self.use_marlin = False
+        if _is_cuda:
+            force_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN")
+            auto_enable = can_auto_enable_marlin_fp8()
+            self.use_marlin = force_marlin or auto_enable
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "fbgemm_fp8"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.bfloat16, torch.float16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return []
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> FBGEMMFp8Config:
+        ignore_list = cls.get_from_keys(config, ["modules_to_not_convert"])
+        input_scale_ub = cls.get_from_keys(config, ["activation_scale_ub"])
+        return cls(ignore_list=ignore_list, input_scale_ub=input_scale_ub)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[QuantizeMethodBase]:
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped(
+                prefix=prefix,
+                ignored_layers=self.ignore_list,
+                fused_mapping=self.packed_modules_mapping,
+            ):
+                return UnquantizedLinearMethod()
+            return FBGEMMFp8LinearMethod(self)
+        return None
+
+
+class FBGEMMFp8LinearMethod(LinearMethodBase):
+
+    def __init__(self, quant_config: FBGEMMFp8Config):
+        self.quant_config = quant_config
+        self.out_dtype = torch.get_default_dtype()
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        del input_size, output_size
+        output_size_per_partition = sum(output_partition_sizes)
+
+        layer.logical_widths = output_partition_sizes
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+
+        # WEIGHT
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition,
+                dtype=torch.float8_e4m3fn,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        weight_scale = ChannelQuantScaleParameter(
+            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        weight_scale[:] = torch.finfo(torch.float32).min
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE UPPER BOUND
+        input_scale_ub = torch.nn.Parameter(
+            torch.tensor(self.quant_config.input_scale_ub, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.input_scale_ub = input_scale_ub
+
+    def process_weights_after_loading(self, layer: Module) -> None:
+        # Plain Parameters (no custom subclasses) are required by torch.compile.
+        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
+        layer.weight = Parameter(layer.weight.data, requires_grad=False)
+
+        weight = layer.weight
+
+        if _is_fp8_fnuz:
+            weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                weight=weight, weight_scale=layer.weight_scale, input_scale=None
+            )
+            if input_scale is not None:
+                layer.input_scale = Parameter(input_scale, requires_grad=False)
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        layer.weight = Parameter(weight.t(), requires_grad=False)
+        if self.quant_config.use_marlin:
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations are not quantized for Marlin.
+            del layer.input_scale_ub
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        if self.quant_config.use_marlin:
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias,
+            )
+
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=None,
+            input_scale_ub=layer.input_scale_ub,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+            use_per_token_if_dynamic=False,
+        )
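For orientation, FBGEMM-FP8 checkpoints store an fp8 weight with one scale per output channel plus an activation_scale_ub bound for dynamic activation quantization. A rough, unfused reference of what the fused path computes; this sketch is illustrative only (fbgemm_fp8_linear_ref is not part of the diff, and the real kernels fuse quantization into the GEMM):

import torch

def fbgemm_fp8_linear_ref(x, w_fp8, w_scale, input_scale_ub, bias=None):
    # Dynamic per-tensor activation scale (use_per_token_if_dynamic=False
    # above), clamped to the checkpoint's activation_scale_ub.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    x_scale = (x.abs().amax().float() / fp8_max).clamp(max=input_scale_ub)
    x_fp8 = (x / x_scale).to(torch.float8_e4m3fn)
    # w_fp8 is (N, K); w_scale is the per-output-channel scale of shape (N, 1).
    out = (x_fp8.float() @ w_fp8.float().t()) * x_scale * w_scale.t()
    out = out if bias is None else out + bias
    return out.to(x.dtype)
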
diff --git a/python/sglang/srt/layers/quantization/marlin_utils.py b/python/sglang/srt/layers/quantization/marlin_utils.py
index d76b900ae..e0b398c25 100644
--- a/python/sglang/srt/layers/quantization/marlin_utils.py
+++ b/python/sglang/srt/layers/quantization/marlin_utils.py
@@ -306,6 +306,13 @@ def marlin_permute_scales(
     return s
 
 
+def marlin_permute_bias(s: torch.Tensor) -> torch.Tensor:
+    origin_shape = s.shape
+    _, scale_perm_single = get_scale_perms()
+    s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+    return s.reshape(*origin_shape).contiguous()
+
+
 def marlin_moe_permute_scales(
     s: torch.Tensor,
     size_k: int,
diff --git a/python/sglang/srt/layers/quantization/marlin_utils_fp8.py b/python/sglang/srt/layers/quantization/marlin_utils_fp8.py
new file mode 100644
index 000000000..94326d71e
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/marlin_utils_fp8.py
@@ -0,0 +1,352 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+from typing import Optional
+
+import torch
+
+from sglang.srt.layers.quantization.marlin_utils import (
+    USE_FP32_REDUCE_DEFAULT,
+    marlin_make_workspace,
+    marlin_permute_bias,
+    marlin_permute_scales,
+    should_use_atomic_add_reduce,
+)
+from sglang.srt.layers.quantization.utils import get_scalar_types
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import gptq_marlin_gemm, gptq_marlin_repack
+
+ScalarType, scalar_types = get_scalar_types()
+
+logger = logging.getLogger(__name__)
+
+
+def fp8_fused_exponent_bias_into_scales(scales):
+    fp8_exponent = 4
+    if scales.dtype == torch.half:
+        target_exponent = 5
+    elif scales.dtype == torch.bfloat16:
+        target_exponent = 8
+    else:
+        raise ValueError(f"Unsupported scales dtype for Marlin FP8: {scales.dtype}")
+    # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
+    # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
+    exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp8_exponent - 1)
+    s = torch.ones_like(scales) * 2
+    s = s**exponent_bias
+    return scales * s
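The folding works because the IEEE-style exponent bias differs between E4M3 (bias 7) and the output dtype: fp16 uses bias 15 and bf16 uses bias 127, so pre-multiplying the scales lets the kernel skip the bias adjustment. A worked check of the factor (plain arithmetic, not part of the diff):

import torch

# E4M3 bias is 7 (= 2**3 - 1), fp16 bias is 15 (= 2**4 - 1),
# so fp16 scales get pre-multiplied by 2**(15 - 7) == 256.
scales = torch.tensor([0.5, 1.0], dtype=torch.half)
folded = scales * 2.0 ** (15 - 7)  # -> tensor([128., 256.])
# For bf16 the factor is 2**(127 - 7) == 2**120, matching the comment above.
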
+def apply_fp8_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    workspace: torch.Tensor,
+    size_n: int,
+    size_k: int,
+    bias: Optional[torch.Tensor],
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
+    # For GPUs that lack FP8 hardware support, we can leverage the
+    # Marlin kernel for fast weight-only FP8 quantization.
+
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (size_n,)
+
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0), n=size_n, k=size_k, device=input.device, dtype=input.dtype
+    )
+
+    output = gptq_marlin_gemm(
+        a=reshaped_x,
+        c=None,
+        b_q_weight=weight,
+        b_bias=bias,
+        b_scales=weight_scale,
+        global_scale=None,
+        b_zeros=None,
+        g_idx=None,
+        perm=None,
+        workspace=workspace,
+        b_q_type=scalar_types.float8_e4m3fn,
+        size_m=reshaped_x.size(0),
+        size_n=size_n,
+        size_k=size_k,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+    )
+
+    return output.reshape(out_shape)
+
+
+def prepare_fp8_layer_for_marlin(
+    layer: torch.nn.Module, size_k_first: bool = True
+) -> None:
+    logger.warning_once(
+        "Your GPU does not have native support for FP8 computation but "
+        "FP8 quantization is being used. Weight-only FP8 compression will "
+        "be used leveraging the Marlin kernel. This may degrade "
+        "performance for compute-heavy workloads."
+    )
+
+    part_size_n = layer.output_size_per_partition
+    part_size_k = layer.input_size_per_partition
+    weight_block_size = getattr(layer, "weight_block_size", None)
+
+    if size_k_first:
+        assert layer.weight.shape == (part_size_k, part_size_n)
+    else:
+        assert layer.weight.shape == (part_size_n, part_size_k)
+
+    device = layer.weight.device
+
+    # WORKSPACE
+    layer.workspace = marlin_make_workspace(device)
+
+    # WEIGHT
+    # Repack weights to Marlin format
+    perm = torch.empty(0, dtype=torch.int, device=device)
+    qweight = pack_fp8_to_int32(layer.weight, size_k_first)
+    if not size_k_first:
+        qweight = qweight.T.contiguous()
+
+    marlin_qweight = gptq_marlin_repack(
+        b_q_weight=qweight,
+        perm=perm,
+        size_k=part_size_k,
+        size_n=part_size_n,
+        num_bits=8,
+    )
+    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)
+
+    # WEIGHT SCALES
+    # Permute scales
+    if "weight_scale" in dir(layer):
+        scales = layer.weight_scale.to(layer.orig_dtype)
+    elif "weight_scale_inv" in dir(layer):
+        scales = layer.weight_scale_inv.to(layer.orig_dtype)
+        del layer.weight_scale_inv
+
+    group_size = -1 if weight_block_size is None else weight_block_size[1]
+
+    # The Marlin kernel only supports channel-wise and group-wise
+    # quantization, so we need to convert the scales.
+    if weight_block_size is None:
+        if scales.nelement() == 1:
+            # tensor-wise quantization -> channel-wise quantization
+            # (1, 1) =>(repeat)=> (1, size_n)
+            scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
+        elif scales.nelement() > 1 and scales.nelement() != part_size_n:
+            assert part_size_n % scales.nelement() == 0
+            s_size = scales.nelement()
+            # tensor-wise quantization (for gate-up proj)
+            # -> channel-wise quantization
+            # (1, s_size) =>(repeat)=> (1, size_n)
+            scales = scales.view(1, s_size)
+            scales = scales.repeat_interleave(part_size_n // s_size, 1)
+        else:
+            # channel-wise quantization
+            # (1, size_n)
+            scales = scales.view(1, part_size_n)
+    else:
+        # block-wise quantization -> group-wise quantization
+        # (size_k // block_size[1], ceil(size_n / block_size[0]))
+        # =>(repeat)=> (size_k // block_size[1], size_n)
+        if not size_k_first:
+            scales = scales.T.contiguous()
+        block_n = weight_block_size[0]
+        scales = scales.repeat_interleave(block_n, 1)
+        # size_n may not be divisible by block_size[0]
+        scales = scales[:, :part_size_n]
+
+    marlin_scales = marlin_permute_scales(
+        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size
+    )
+    marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
+    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)
+
+    if hasattr(layer, "bias") and layer.bias is not None:
+        assert layer.bias.shape == (part_size_n,)
+        bias = marlin_permute_bias(layer.bias)
+        layer.bias = torch.nn.Parameter(bias, requires_grad=False)
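The scale conversions above are pure shape manipulation. A small walk-through with toy sizes (values and sizes are illustrative, not from the diff):

import torch

# Tensor-wise -> channel-wise: one scalar scale broadcast across size_n = 8.
s = torch.tensor([[0.02]])
assert s.view(1, 1).repeat_interleave(8, 1).shape == (1, 8)

# Block-wise (128x128 blocks) -> group-wise: each scale column is repeated
# block_n times along n, then trimmed because size_n % block_n may be nonzero.
size_k, size_n, block_n, block_k = 256, 200, 128, 128
blocked = torch.rand(size_k // block_k, (size_n + block_n - 1) // block_n)
groupwise = blocked.repeat_interleave(block_n, 1)[:, :size_n]
assert groupwise.shape == (2, 200)
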
Weight-only FP8 compression will " + "be used leveraging the Marlin kernel. This may degrade " + "performance for compute-heavy workloads." + ) + + e = layer.num_experts + k = layer.hidden_size + n = layer.intermediate_size_per_partition + weight_block_size = getattr(layer, "weight_block_size", None) + + # WORKSPACE + device = layer.w13_weight.device + layer.workspace = marlin_make_workspace(device, 4) + perm = torch.empty(0, dtype=torch.int, device=device) + + # WEIGHT + # Repack weights to marlin format + for name in ["w13_weight", "w2_weight"]: + weight = getattr(layer, name) + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + if size_k_first: + assert weight.shape == (e, size_k, size_n) + else: + assert weight.shape == (e, size_n, size_k) + + for i in range(e): + qweight = pack_fp8_to_int32(weight[i], size_k_first) + if not size_k_first: + qweight = qweight.T.contiguous() + + marlin_qweight = gptq_marlin_repack( + b_q_weight=qweight, perm=perm, size_k=size_k, size_n=size_n, num_bits=8 + ) + tensor_list.append(marlin_qweight) + + weight = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + weight = torch.nn.Parameter(weight, requires_grad=False) + + setattr(layer, name, weight) + + # WEIGHT SCALES + # Permute scales + group_size = -1 if weight_block_size is None else weight_block_size[1] + + for name in ["w13", "w2"]: + if name + "_weight_scale" in dir(layer): + new_name = name + "_weight_scale" + scales = getattr(layer, new_name).to(layer.orig_dtype) + delattr(layer, new_name) + elif name + "_weight_scale_inv" in dir(layer): + new_name = name + "_weight_scale_inv" + scales = getattr(layer, new_name).to(layer.orig_dtype) + delattr(layer, new_name) + + tensor_list = [] + if "w13" in name: + size_n, size_k = n * 2, k + else: + size_n, size_k = k, n + + # marlin kernel only support channel-wise and group-wise quantization + # we need to convert the scales + if weight_block_size is None: + if scales.nelement() == e: + # tensor-wise quantization -> channel-wise quantization + # (e, 1, 1) =>(repeat)=> (e, 1, size_n) + scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2) + elif scales.nelement() > e and scales.nelement() != e * size_n: + assert (e * size_n) % scales.nelement() == 0 + s_size = scales.nelement() // e + # tensor-wise quantization (for gate-up proj) + # -> channel-wise quantization + # (e, 1, s_size) =>(repeat)=> (e, 1, size_n) + scales = scales.view(e, 1, s_size) + scales = scales.repeat_interleave(size_n // s_size, 2) + else: + # channel-wise quantization + # (e, 1, size_n) + scales = scales.view(e, 1, size_n) + else: + # block-wise quantization -> group-wise quantization + # (e, size_k // block_size[1], ceil(size_n / block_size[0])) + # =>(repeat)=> (e, size_k // block_size[1], size_n) + if not size_k_first: + scales = scales.permute(0, 2, 1) + block_n = weight_block_size[0] + scales = scales.repeat_interleave(block_n, 2) + # size_n may not divisible by block_size[0] + scales = scales[..., :size_n].contiguous() + + for i in range(e): + marlin_scales = marlin_permute_scales( + s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size + ) + tensor_list.append(marlin_scales) + + scales = torch.cat([x.unsqueeze(0) for x in tensor_list], 0) + scales = fp8_fused_exponent_bias_into_scales(scales) + scales = torch.nn.Parameter(scales, requires_grad=False) + + setattr(layer, name + "_weight_scale", scales) + + # BIAS + # Permute bias + for name in ["w13_bias", "w2_bias"]: + if not hasattr(layer, name): + continue + bias = 
+def pack_fp8_to_int32(
+    fp8_tensor: torch.Tensor, size_k_first: bool = True
+) -> torch.Tensor:
+    """
+    Repack FP8 weights to GPTQ format (packed int32 elements).
+    """
+    assert fp8_tensor.dtype == torch.float8_e4m3fn
+    assert fp8_tensor.ndim == 2
+
+    fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
+    fp8_tensor = fp8_tensor.contiguous()
+    # fp8_tensor is contiguous and has shape (N, K) now.
+    # With `.view(torch.int32)`, it becomes (N, K // 4).
+    int32_tensor = fp8_tensor.view(torch.int32)
+    return int32_tensor.T.contiguous() if size_k_first else int32_tensor
+
+
+def marlin_quant_fp8_torch(weight, group_size):
+    size_n, size_k = weight.shape
+    device = weight.device
+
+    # 448 is the maximum representable value of float8_e4m3fn.
+    if group_size != -1:
+        scales = weight.view(size_n, -1, group_size).abs().max(-1)[0] / 448
+        repeated_scales = scales.repeat_interleave(group_size, 1)
+        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
+        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
+    else:
+        # group_size == -1 means channel-wise quantization;
+        # view(size_n, 1, -1) lets PyTorch infer size_k for the last dim.
+        scales = weight.view(size_n, 1, group_size).abs().max(-1)[0] / 448
+        repeated_scales = scales.repeat_interleave(size_k, 1)
+        fp8_weight = (weight / repeated_scales).to(torch.float8_e4m3fn)
+        weight_ref = fp8_weight.to(weight.dtype) * repeated_scales
+
+    packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
+    marlin_qweight = gptq_marlin_repack(
+        b_q_weight=packed_weight,
+        perm=torch.empty(0, dtype=torch.int, device=device),
+        size_k=size_k,
+        size_n=size_n,
+        num_bits=8,
+    )
+
+    marlin_scales = marlin_permute_scales(
+        s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size
+    )
+    marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
+
+    return weight_ref.T, marlin_qweight, marlin_scales
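A rough end-to-end check of the reference quantizer against the Marlin GEMM; this assumes a CUDA GPU with sgl_kernel installed, and the sizes and tolerances are illustrative, not a definitive test:

import torch
from sglang.srt.layers.quantization.marlin_utils import marlin_make_workspace
from sglang.srt.layers.quantization.marlin_utils_fp8 import (
    apply_fp8_marlin_linear,
    marlin_quant_fp8_torch,
)

size_m, size_n, size_k = 8, 512, 1024
x = torch.randn(size_m, size_k, dtype=torch.half, device="cuda")
weight = torch.randn(size_n, size_k, dtype=torch.half, device="cuda")

# weight_ref comes back transposed to (size_k, size_n).
weight_ref, qweight, scales = marlin_quant_fp8_torch(weight, group_size=-1)
out = apply_fp8_marlin_linear(
    input=x, weight=qweight, weight_scale=scales,
    workspace=marlin_make_workspace(x.device),
    size_n=size_n, size_k=size_k, bias=None,
)
torch.testing.assert_close(out, x @ weight_ref, atol=1e-1, rtol=1e-2)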