From 287427e2e66aef4e4d857cfd666fe849e9f73617 Mon Sep 17 00:00:00 2001
From: Zhiyu
Date: Mon, 6 Jan 2025 14:54:52 -0800
Subject: [PATCH] Enable Nvidia's ModelOpt fp8 quantized models (#2535)

---
 python/sglang/srt/layers/linear.py            |   1 +
 python/sglang/srt/layers/modelopt_quant.py    | 173 ++++++++++++++++++
 .../srt/layers/quantization/__init__.py       |   2 +
 .../sglang/srt/model_executor/model_runner.py |   1 +
 python/sglang/srt/server_args.py              |   8 +
 5 files changed, 185 insertions(+)
 create mode 100644 python/sglang/srt/layers/modelopt_quant.py

diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py
index d5dfb8718..b828c0391 100644
--- a/python/sglang/srt/layers/linear.py
+++ b/python/sglang/srt/layers/linear.py
@@ -44,6 +44,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "MarlinLinearMethod",
     "GPTQLinearMethod",
     "QQQLinearMethod",
+    "ModelOptFp8LinearMethod",
 ]
diff --git a/python/sglang/srt/layers/modelopt_quant.py b/python/sglang/srt/layers/modelopt_quant.py
new file mode 100644
index 000000000..2c0887df2
--- /dev/null
+++ b/python/sglang/srt/layers/modelopt_quant.py
@@ -0,0 +1,173 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/modelopt.py
+
+import logging
+from typing import Any, Dict, List, Optional
+
+import torch
+from torch.nn.parameter import Parameter
+from vllm.model_executor.layers.linear import LinearBase
+from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
+    apply_fp8_linear,
+    cutlass_fp8_supported,
+    requantize_with_max_scale,
+)
+from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
+
+from sglang.srt.layers.linear import LinearMethodBase
+from sglang.srt.layers.quantization.base_config import (
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+
+# Initialize logger for the module
+logger = logging.getLogger(__name__)
+
+# Supported activation schemes for the current configuration
+ACTIVATION_SCHEMES = ["static"]
+
+
+class ModelOptFp8Config(QuantizationConfig):
+    """Configuration for ModelOpt FP8 quantization, including serialization and compatibility checks."""
+
+    def __init__(self, is_checkpoint_fp8_serialized: bool = False) -> None:
+        """
+        Args:
+            is_checkpoint_fp8_serialized (bool): Indicates if the checkpoint uses serialized FP8 format.
+        """
+        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
+        if is_checkpoint_fp8_serialized:
+            logger.warning(
+                "Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
+            )
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "modelopt"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 89  # Minimum compute capability (SM 8.9, i.e., Ada Lovelace; Hopper also qualifies).
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp8Config":
+        quant_method = cls.get_from_keys(config, ["quantization"]).get("quant_algo", "")
+
+        if "FP8" not in quant_method:
+            raise ValueError(
+                "ModelOpt only supports static FP8 quantization in SGLang. "
+                "Check the `hf_quant_config.json` file for your model's configuration."
+            )
+
+        return cls(is_checkpoint_fp8_serialized=True)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        return ModelOptFp8LinearMethod(self) if isinstance(layer, LinearBase) else None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp8LinearMethod(LinearMethodBase):
+    """Linear method for ModelOpt static FP8 quantization.
+
+    Supports loading FP8 checkpoints with static weight and activation scales.
+    Future support may include dynamic scales.
+
+    **Limitations**:
+    1. Only supports per-tensor quantization due to `torch._scaled_mm` limitations.
+    2. Only supports the `float8_e4m3fn` data type.
+
+    Args:
+        quant_config (ModelOptFp8Config): The ModelOpt quantization configuration.
+    """
+
+    def __init__(self, quant_config: ModelOptFp8Config):
+        super().__init__()
+        self.quant_config = quant_config
+        self.cutlass_fp8_supported = cutlass_fp8_supported()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        """Creates and registers weights, weight scales, and input scales for FP8 quantization."""
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_fp8_serialized
+            else params_dtype
+        )
+
+        # Set layer attributes
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+
+        # Register weight
+        layer.register_parameter(
+            "weight",
+            ModelWeightParameter(
+                data=torch.empty(
+                    output_size_per_partition,
+                    input_size_per_partition,
+                    dtype=weight_dtype,
+                ),
+                input_dim=1,
+                output_dim=0,
+                weight_loader=weight_loader,
+            ),
+        )
+
+        if self.quant_config.is_checkpoint_fp8_serialized:
+            # Register weight and input scales
+            for scale_name in ["weight_scale", "input_scale"]:
+                layer.register_parameter(
+                    scale_name,
+                    PerTensorScaleParameter(
+                        data=torch.full(
+                            (len(output_partition_sizes),),
+                            torch.finfo(torch.float32).min,
+                        ),
+                        weight_loader=weight_loader,
+                    ),
+                )
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        """Requantizes weights after loading using the maximum scale."""
+        max_w_scale, quantized_weight = requantize_with_max_scale(
+            layer.weight, layer.weight_scale, layer.logical_widths
+        )
+        layer.weight = Parameter(quantized_weight.t(), requires_grad=False)
+        layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+        layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Applies FP8 linear transformation."""
+        return apply_fp8_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            input_scale=layer.input_scale,
+            bias=bias,
+            cutlass_fp8_supported=self.cutlass_fp8_supported,
+        )
diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index ae9319f28..2ff570ba1 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -18,6 +18,7 @@
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig
 from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config
 from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
@@ -32,6 +33,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
+    "modelopt": ModelOptFp8Config,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
     "gptq_marlin": GPTQMarlinConfig,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 41905e272..72cc0f83a 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -150,6 +150,7 @@ class ModelRunner:
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
+                "modelopt_config": server_args.modelopt_config,
             }
         )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index b61b4b2dc..6950d00b3 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -149,6 +149,7 @@ class ServerArgs:
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
     torchao_config: str = ""
+    modelopt_config: str = ""
     enable_nan_detection: bool = False
     enable_p2p_check: bool = False
     triton_attention_reduce_in_fp32: bool = False
@@ -361,6 +362,7 @@
                 "awq_marlin",
                 "bitsandbytes",
                 "gguf",
+                "modelopt",
             ],
             help="The quantization method.",
         )
@@ -808,6 +810,12 @@
             default=ServerArgs.torchao_config,
             help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row",
         )
+        parser.add_argument(
+            "--modelopt-config",
+            type=str,
+            default=ServerArgs.modelopt_config,
+            help="Optimize the model with nvidia-modelopt. Experimental feature. Current choices are: fp8",
+        )
         parser.add_argument(
             "--enable-nan-detection",
             action="store_true",
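
Note for reviewers (not part of the patch): on the static per-tensor path, apply_fp8_linear ultimately reduces to a single torch._scaled_mm call (or an equivalent CUTLASS kernel when cutlass_fp8_supported() is true). The standalone sketch below illustrates those numerics; it is not SGLang code. The shapes are arbitrary (kept as multiples of 16, which the FP8 GEMM requires), the scales stand in for the checkpoint's static input_scale/weight_scale, and torch._scaled_mm is a private PyTorch API whose signature has changed across releases (the keyword form below matches recent releases, roughly 2.3+). Running it needs a GPU with compute capability 8.9 or newer, consistent with get_min_capability.

    import torch

    # Illustrative shapes only; the FP8 GEMM wants dimensions divisible by 16.
    x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)  # activations
    w = torch.randn(64, 32, device="cuda", dtype=torch.bfloat16)  # weight (out_features, in_features)

    finfo = torch.finfo(torch.float8_e4m3fn)
    # Static per-tensor scales, standing in for the checkpoint's input_scale / weight_scale.
    x_scale = (x.abs().max() / finfo.max).float()
    w_scale = (w.abs().max() / finfo.max).float()

    # Quantize to float8_e4m3fn, the only dtype this linear method supports.
    x_fp8 = (x / x_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    w_fp8 = (w / w_scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)

    # Per-tensor scaled matmul. The second operand must be column-major, which
    # is why process_weights_after_loading stores the weight transposed.
    y = torch._scaled_mm(
        x_fp8,
        w_fp8.t(),
        scale_a=x_scale,
        scale_b=w_scale,
        out_dtype=torch.bfloat16,
    )
    assert y.shape == (16, 64)

End to end: from_config only inspects config["quantization"]["quant_algo"] and requires it to mention FP8, so a ModelOpt-exported checkpoint needs an hf_quant_config.json whose quantization.quant_algo is an FP8 variant. Such a checkpoint can then be served with --quantization modelopt, while the new --modelopt-config flag is plumbed through to the model runner.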