diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index e23f089f1..5f4e2f0dd 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -279,6 +279,7 @@ class ModelConfig:
             "moe_wna16",
         ]
         compatible_quantization_methods = {
+            "modelopt_fp4": ["modelopt"],
             "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
             "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
         }
diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py
index c4627082f..3f9ca3dee 100644
--- a/python/sglang/srt/layers/linear.py
+++ b/python/sglang/srt/layers/linear.py
@@ -47,6 +47,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQLinearMethod",
     "FBGEMMFp8LinearMethod",
     "ModelOptFp8LinearMethod",
+    "ModelOptFp4LinearMethod",
     "IPEXAWQLinearMethod",
 ]
diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 8bd410df6..885f9fe50 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -59,7 +59,10 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.modelopt_quant import (
+    ModelOptFp4Config,
+    ModelOptFp8Config,
+)
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
@@ -69,6 +72,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
     "modelopt": ModelOptFp8Config,
+    "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
     "moe_wna16": MoeWNA16Config,
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index eea9fa573..fc32e53f6 100644
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -22,6 +22,10 @@ from sglang.srt.layers.quantization.utils import (
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import is_cuda_available
+
+if is_cuda_available():
+    from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
 
 # Initialize logger for the module
 logger = logging.getLogger(__name__)
@@ -215,3 +219,245 @@ class ModelOptFp8KVCacheMethod(BaseKVCacheMethod):
 
     def __init__(self, quant_config: ModelOptFp8Config):
         super().__init__(quant_config)
+
+
+class ModelOptFp4Config(QuantizationConfig):
+    """Config class for ModelOpt FP4 (NVFP4) quantization."""
+
+    def __init__(
+        self,
+        is_checkpoint_nvfp4_serialized: bool = False,
+        kv_cache_quant_algo: Optional[str] = None,
+        group_size: Optional[int] = None,
+        exclude_modules: Optional[List[str]] = None,
+    ) -> None:
+        self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
+        if is_checkpoint_nvfp4_serialized:
+            logger.warning(
+                "Detected nvfp4 checkpoint. Please note that the "
+                "format is experimental and subject to change."
+            )
+        self.group_size = group_size
+        self.kv_cache_quant_algo = kv_cache_quant_algo
+        self.exclude_modules = exclude_modules
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "modelopt_fp4"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+        return [torch.bfloat16, torch.half, torch.float8_e4m3fn]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 100
+
+    @classmethod
+    def get_config_filenames(cls) -> List[str]:
+        return ["hf_quant_config.json"]
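+    # A minimal sketch of the `hf_quant_config.json` content that
+    # `from_config` below expects; the values shown are illustrative,
+    # not taken from any particular checkpoint:
+    #
+    #     {
+    #         "quantization": {
+    #             "quant_algo": "NVFP4",
+    #             "kv_cache_quant_algo": "FP8",
+    #             "group_size": 16,
+    #             "exclude_modules": ["lm_head"]
+    #         }
+    #     }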
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "ModelOptFp4Config":
+        quant_config = cls.get_from_keys(config, ["quantization"])
+        quant_method = quant_config["quant_algo"]
+        if quant_method not in ["FP8", "NVFP4"]:
+            raise ValueError(
+                "ModelOpt currently only supports FP8 and NVFP4"
+                " quantization in sglang. Please check the "
+                "`hf_quant_config.json` file for your model's "
+                "quant configuration."
+            )
+        is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
+        kv_cache_quant_algo = quant_config["kv_cache_quant_algo"]
+        group_size = quant_config["group_size"]
+        exclude_modules = quant_config["exclude_modules"]
+        if not (group_size and kv_cache_quant_algo and exclude_modules):
+            raise ValueError(
+                "NVFP4 quantization requires group_size, "
+                "kv_cache_quant_algo, and exclude_modules to be "
+                "specified in hf_quant_config.json"
+            )
+        return cls(
+            is_checkpoint_nvfp4_serialized,
+            kv_cache_quant_algo,
+            group_size,
+            exclude_modules,
+        )
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        if self.exclude_modules and any(
+            module in prefix for module in self.exclude_modules
+        ):
+            return None
+
+        if isinstance(layer, LinearBase):
+            return ModelOptFp4LinearMethod(self)
+        if self.kv_cache_quant_algo and isinstance(layer, RadixAttention):
+            return ModelOptFp8KVCacheMethod(self)
+
+        return None
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class ModelOptFp4LinearMethod(LinearMethodBase):
+    """Linear method for NVFP4.
+
+    Supports loading NVFP4 checkpoints with the following structure:
+
+    | Tensor Name    | datatype      | shape     |
+    |----------------|---------------|-----------|
+    | input_scale    | torch.float32 | scalar    |
+    | weight         | NVFP4(SE2M1)  | [X, Y/2]  |
+    | weight_scale   | FP8-E4M3      | [X, Y/16] |
+    | weight_scale_2 | torch.float32 | scalar    |
+
+    The weights are quantized per block of 16 elements.
+
+    Args:
+        quant_config: The ModelOpt quantization config.
+    """
+
+    def __init__(self, quant_config: ModelOptFp4Config):
+        self.quant_config = quant_config
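+    # Layout notes for the parameters created below: `weight` packs two
+    # FP4 (E2M1) values into one uint8 along the input dimension (hence
+    # input_size_per_partition // 2), and `weight_scale` holds one
+    # FP8-E4M3 scale per group of `group_size` (16) input elements.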
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del input_size, output_size
+        if not self.quant_config.is_checkpoint_nvfp4_serialized:
+            raise ValueError(
+                "NVFP4 quantization was selected, but dynamic "
+                "quantization is not supported."
+            )
+
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        layer.logical_widths = output_partition_sizes
+
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        if input_size_per_partition % 16 != 0:
+            raise ValueError(
+                "Unsupported model: the input feature size must be "
+                "a multiple of 16"
+            )
+
+        weight_dtype = (
+            torch.float8_e4m3fn
+            if self.quant_config.is_checkpoint_nvfp4_serialized
+            else params_dtype
+        )
+
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                # Two FP4 values are packed into one uint8 along the
+                # input dimension.
+                output_size_per_partition,
+                input_size_per_partition // 2,
+                dtype=torch.uint8,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        input_scale = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("input_scale", input_scale)
+
+        weight_scale_2 = PerTensorScaleParameter(
+            data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale_2", weight_scale_2)
+
+        weight_scale = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition // self.quant_config.group_size,
+                dtype=weight_dtype,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight_scale", weight_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        input_scale_2 = layer.input_scale.max().to(torch.float32)
+        weight_scale_2 = layer.weight_scale_2.max().to(torch.float32)
+        layer.input_scale = Parameter(input_scale_2, requires_grad=False)
+        layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False)
+        layer.alpha = Parameter(
+            layer.input_scale * layer.weight_scale_2, requires_grad=False
+        )
+
+        # Pad and blockwise interleave weight_scale
+        scales = layer.weight_scale
+        scale_ndim = scales.ndim
+        if scale_ndim == 2:
+            scales = scales.unsqueeze(0)
+        assert scales.ndim == 3
+        B, M, K = scales.shape
+        round_up_multiple = lambda x, m: (x + m - 1) // m * m
+        M_padded = round_up_multiple(M, 128)
+        K_padded = round_up_multiple(K, 4)
+        padded_scales = torch.zeros((B, M_padded, K_padded), dtype=scales.dtype)
+        padded_scales[:B, :M, :K] = scales
+        batches, rows, cols = padded_scales.shape
+        assert rows % 128 == 0
+        assert cols % 4 == 0
+        padded_scales = padded_scales.reshape(batches, rows // 128, 4, 32, cols // 4, 4)
+        padded_scales = padded_scales.permute((0, 1, 4, 3, 2, 5))
+        padded_scales = padded_scales.contiguous().cuda()
+        padded_scales = (
+            padded_scales.reshape(M, K)
+            if scale_ndim == 2
+            else padded_scales.reshape(B, M, K)
+        )
+        layer.weight_scale_interleaved = Parameter(padded_scales, requires_grad=False)
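+    # The forward pass below (1) quantizes the BF16/FP16 activation to
+    # FP4, with the kernel also emitting the activation's interleaved
+    # block scales, then (2) runs a CUTLASS FP4 GEMM against the
+    # pre-interleaved weight scales, with the global
+    # alpha = input_scale * weight_scale_2 passed to the kernel.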
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        output_dtype = x.dtype
+        x_m, _ = x.shape
+        w_n, _ = layer.weight.shape
+        output_shape = [x_m, w_n]
+
+        # Quantize the BF16 or FP16 input to FP4 and an interleaved block scale
+        x_fp4, x_scale_interleaved = scaled_fp4_quant(x, 1 / layer.input_scale)
+
+        assert x_fp4.dtype == torch.uint8
+        assert x_scale_interleaved.dtype == torch.float8_e4m3fn
+        assert layer.weight.dtype == torch.uint8
+        assert layer.weight_scale_interleaved.dtype == torch.float8_e4m3fn
+        assert layer.alpha.dtype == torch.float32
+
+        out = cutlass_scaled_fp4_mm(
+            x_fp4,
+            layer.weight,
+            x_scale_interleaved,
+            layer.weight_scale_interleaved,
+            layer.alpha,
+            output_dtype,
+        )
+        if bias is not None:
+            out = out + bias
+        return out.view(*output_shape)
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 71b7910ee..a65c90de8 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -495,6 +495,7 @@ class ServerArgs:
                 "bitsandbytes",
                 "gguf",
                 "modelopt",
+                "modelopt_fp4",
                 "w8a8_int8",
                 "w8a8_fp8",
                 "moe_wna16",
diff --git a/sgl-kernel/README.md b/sgl-kernel/README.md
index 2afcfc3ce..d5f58c9c3 100644
--- a/sgl-kernel/README.md
+++ b/sgl-kernel/README.md
@@ -156,6 +156,14 @@
 unset CCACHE_READONLY
 python -m uv build --wheel -Cbuild-dir=build --color=always .
 ```
+
+##### Configuring CMake Build Options
+CMake options can be configured by adding `-Ccmake.define.
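+For example, to pass a CMake define through to the build (a sketch; the
+option name below is a placeholder, not necessarily a real sgl-kernel
+option):
+
+```bash
+python -m uv build --wheel -Cbuild-dir=build -Ccmake.define.SOME_OPTION=ON .
+```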