# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from fnmatch import fnmatch from typing import TYPE_CHECKING, Any import torch from torch.nn.parameter import Parameter import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.kernels.linear import ( init_fp8_linear_kernel, ) from vllm.model_executor.layers.attention import Attention, MLAAttention from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.fused_moe.oracle.fp8 import ( convert_to_fp8_moe_kernel_format, make_fp8_moe_kernel, make_fp8_moe_quant_config, select_fp8_moe_backend, ) from vllm.model_executor.layers.fused_moe.oracle.nvfp4 import ( convert_to_nvfp4_moe_kernel_format, is_global_sf_supported_for_nvfp4_backend, make_nvfp4_moe_kernel, make_nvfp4_moe_quant_config, select_nvfp4_moe_backend, ) from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, UnquantizedLinearMethod, ) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.fp8_utils import ( W8A8BlockFp8LinearOp, process_fp8_input_tensor_strategy_moe, process_fp8_weight_tensor_strategy_moe, ) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( get_marlin_input_dtype, ) from vllm.model_executor.layers.quantization.utils.mxfp8_utils import ( MXFP8_BLOCK_SIZE, MXFP8_SCALE_DTYPE, MXFP8_VALUE_DTYPE, Mxfp8LinearBackend, Mxfp8LinearOp, swizzle_mxfp8_scale, ) from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( apply_nvfp4_linear, convert_to_nvfp4_linear_kernel_format, select_nvfp4_linear_backend, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, is_layer_skipped, kFp8DynamicTokenSym, kFp8StaticTensorSym, kFp8StaticTokenSym, kNvfp4Dynamic, kNvfp4Static, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_block_fp8_supported, requantize_with_max_scale, ) from vllm.model_executor.parameter import ( BlockQuantScaleParameter, ChannelQuantScaleParameter, ModelWeightParameter, PerTensorScaleParameter, ) from vllm.model_executor.utils import replace_parameter if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) QUANT_ALGOS = [ # FP8 (per-tensor weight + optional static activation scale). "FP8", # FP8 per-channel weight scale + per-token activation scale. "FP8_PER_CHANNEL_PER_TOKEN", # FP8 per-block weight-only (ModelOpt may emit this as lowercase). "FP8_PB_WO", # FP4 "NVFP4", # MXFP8 "MXFP8", # MIXED_PRECISION, "MIXED_PRECISION", ] KV_CACHE_QUANT_ALGOS = ["FP8"] class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): """ Supports loading kv-cache scaling factors from FP8 checkpoints. """ def __init__(self, quant_config: "ModelOptQuantConfigBase"): super().__init__(quant_config) class ModelOptQuantConfigBase(QuantizationConfig): LinearMethodCls: type = LinearMethodBase FusedMoEMethodCls: type = FusedMoEMethodBase KVCacheMethodCls: type = BaseKVCacheMethod def __init__( self, exclude_modules: list[str], ): super().__init__() self.exclude_modules: list[str] = exclude_modules def is_layer_excluded(self, prefix: str) -> bool: """ Check if a layer should be excluded from quantization. Handles both exact matching (for fused layers) and ModelOpt wildcard matching. The ModelOpt exclude_modules list is a list of wildcards. """ if len(self.exclude_modules) == 0: return False # First check exact matching with fused layer support if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping): return True # TODO: This special hard coded logic is not needed for quantized checkpoints # generated by ModelOpt >= 0.39.0 where they are handled natually by the # exclude_modules config. But need to keep them for loading quantized # checkpoints generated by older versions. Then check substring matching # for patterns not caught by exact match for exclude_module in self.exclude_modules: # Skip exact matches already handled above if exclude_module != prefix and ( exclude_module in prefix or ( prefix.startswith("language_model.") and exclude_module in prefix.removeprefix("language_model.") ) ): return True # modelopt exclude modules are not simple strings, they are wildcards for wildcard_pattern in self.exclude_modules: if fnmatch(prefix, wildcard_pattern): return True return False def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> "QuantizeMethodBase | None": # handle kv-cache first so we can focus only on weight quantization thereafter if isinstance(layer, (Attention, MLAAttention)): return self.KVCacheMethodCls(self) # handle exclusion if self.is_layer_excluded(prefix): if isinstance(layer, LinearBase): return UnquantizedLinearMethod() return None # TODO: This special hard coded logic is not needed for quantized checkpoints # generated by ModelOpt >= 0.39.0 where they are handled natually by the # exclude_modules config. But need to keep them for loading quantized # checkpoints generated by older versions. Then check substring matching # for patterns not caught by exact match if "vision_tower" in prefix or "vision_model" in prefix: return UnquantizedLinearMethod() # now, the layer is quantized, handle it here if isinstance(layer, LinearBase): quant_method = self.LinearMethodCls(self) if getattr(quant_method, "backend", "") == "marlin": quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) return quant_method elif isinstance(layer, FusedMoE): quant_method = self.FusedMoEMethodCls( quant_config=self, moe_config=layer.moe_config ) if getattr(quant_method, "backend", "") == "marlin": quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) return quant_method return None def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if len(self.exclude_modules) > 0: # This is a workaround for the weights remapping issue: # https://github.com/vllm-project/vllm/issues/28072 # Right now, the Nvidia ModelOpt library use just one wildcard pattern: # module_path* # It gets applied if the whole tree of modules rooted at module_path # is not quantized. Here we replace such pattern by 2 patterns that are # collectively equivalent to the original pattern: # module_path # module_path.* new_exclude_modules = [] for exclude in self.exclude_modules: if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".": new_exclude_modules.append(exclude[:-1]) new_exclude_modules.append(exclude[:-1] + ".*") else: new_exclude_modules.append(exclude) self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules) @staticmethod def _extract_modelopt_quant_algo( hf_quant_cfg: dict[str, Any] | None, ) -> str | None: """Extract upper-cased quant_algo from a modelopt config. Returns the quant_algo string (upper-cased), or None if the config is not a modelopt config. """ if hf_quant_cfg is None: return None if hf_quant_cfg.get("quant_method", "").lower() != "modelopt": return None if "quantization" in hf_quant_cfg: quant_config = hf_quant_cfg["quantization"] if isinstance(quant_config, dict): return str(quant_config.get("quant_algo", "")).upper() return None return str(hf_quant_cfg.get("quant_algo", "")).upper() @staticmethod def get_config_filenames() -> list[str]: return ["hf_quant_config.json"] @classmethod def _from_config( cls, *, quant_method: str, kv_cache_quant_method: str | None, exclude_modules: list[str], original_config: dict[str, Any], group_size: int | None, ) -> "ModelOptQuantConfigBase": raise NotImplementedError("Please implement this function in sub classes") @classmethod def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase": # Handle both ModelOpt format and compressed-tensors style format if "quantization" in config: # Traditional ModelOpt format: # {"quantization": {"quant_algo": "..."}} quant_config = cls.get_from_keys(config, ["quantization"]) if not isinstance(quant_config, dict): raise ValueError("Expected 'quantization' to be a dictionary in config") quant_method = quant_config.get("quant_algo") # Handle kv_cache_quant_algo with proper type validation kv_cache_quant_method = quant_config.get("kv_cache_quant_algo") # Handle group_size with proper type validation group_size_raw = quant_config.get("group_size") # "exclude_modules" is the key in the legacy hf_quant_config.json exclude_modules = quant_config.get("exclude_modules", []) else: # Compressed-tensors style format (config.json quantization_config): # {"quant_algo": "...", "quant_method": "modelopt"} quant_method = config.get("quant_algo") # "kv_cache_scheme" (a dict) instead of "kv_cache_quant_algo" (a string). kv_cache_scheme = config.get("kv_cache_scheme") if isinstance(kv_cache_scheme, dict) and ( kv_cache_scheme.get("type") == "float" and kv_cache_scheme.get("num_bits") == 8 ): kv_cache_quant_method = "FP8" else: kv_cache_quant_method = None # "ignore" is the key in config.json exclude_modules = config.get("ignore", []) group_size_raw = config.get("group_size") if not quant_method: raise ValueError("Missing 'quant_algo' in quantization config") # Normalize quant_algo for robust matching (ModelOpt may emit lowercase). quant_method = str(quant_method).upper() if kv_cache_quant_method is None: # No KV cache quantization, keep this branch just to have this comment pass elif not isinstance(kv_cache_quant_method, str): raise ValueError( f"kv_cache_quant_algo must be a string, got " f"{type(kv_cache_quant_method)}" ) else: kv_cache_quant_method = kv_cache_quant_method.upper() if not isinstance(exclude_modules, list): raise ValueError( f"exclude_modules must be a list, got {type(exclude_modules)}" ) if group_size_raw is None: group_size = None elif isinstance(group_size_raw, int): group_size = group_size_raw else: try: group_size = int(group_size_raw) except (ValueError, TypeError): raise ValueError( f"group_size must be an integer, got {type(group_size_raw)}" ) from None if quant_method not in QUANT_ALGOS: raise ValueError( f"ModelOpt currently only supports: {QUANT_ALGOS} " "quantizations in vLLM. Please check the " "`hf_quant_config.json` file for your model's " "quant configuration." ) return cls._from_config( quant_method=quant_method, kv_cache_quant_method=kv_cache_quant_method, exclude_modules=exclude_modules, group_size=group_size, original_config=config, ) class ModelOptFp8Config(ModelOptQuantConfigBase): """Config class for ModelOpt FP8.""" def __init__( self, quant_method: str, is_checkpoint_fp8_serialized: bool, kv_cache_quant_method: str | None, exclude_modules: list[str], ) -> None: super().__init__(exclude_modules) self.quant_method = quant_method self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized self.kv_cache_quant_method = kv_cache_quant_method if is_checkpoint_fp8_serialized: logger.warning( "Detected ModelOpt fp8 checkpoint (quant_algo=%s). Please note " "that the format is experimental and could change.", quant_method, ) # Select LinearMethod implementation based on quant_algo. if self.quant_method == "FP8": self.LinearMethodCls = ModelOptFp8LinearMethod elif self.quant_method == "FP8_PER_CHANNEL_PER_TOKEN": self.LinearMethodCls = ModelOptFp8PcPtLinearMethod elif self.quant_method == "FP8_PB_WO": self.LinearMethodCls = ModelOptFp8PbWoLinearMethod else: raise ValueError( "Unsupported ModelOpt FP8 quant_algo for vLLM: " f"{self.quant_method}. Supported: FP8 / " "FP8_PER_CHANNEL_PER_TOKEN / FP8_PB_WO." ) def get_name(self) -> QuantizationMethods: return "modelopt" def get_supported_act_dtypes(self) -> list[torch.dtype]: return [torch.bfloat16, torch.half] @classmethod def get_min_capability(cls) -> int: return 89 @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant ) -> QuantizationMethods | None: algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) if algo is not None and algo == "FP8": return "modelopt" return None @classmethod def _from_config( cls, *, quant_method: str, kv_cache_quant_method: str | None, exclude_modules: list[str], original_config: dict[str, Any], **kwargs: Any, ) -> "ModelOptFp8Config": is_checkpoint_fp8_serialized = "FP8" in quant_method return cls( quant_method, is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules, ) class ModelOptFp8LinearMethod(LinearMethodBase): """Linear method for Model Optimizer static quantization. Supports loading FP8 checkpoints with static weight scale and activation scale. Future support might be added for dynamic scales. Limitations: 1. Only support per-tensor quantization due to torch._scaled_mm support. 2. Only support float8_e4m3fn datatype Args: quant_config: The ModelOpt quantization config. """ def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config self.fp8_linear = init_fp8_linear_kernel( activation_quant_key=kFp8StaticTensorSym, weight_quant_key=kFp8StaticTensorSym, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): del input_size, output_size output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition weight_dtype = ( torch.float8_e4m3fn if self.quant_config.is_checkpoint_fp8_serialized else params_dtype ) weight = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, dtype=weight_dtype ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight", weight) if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE weight_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: weight = layer.weight max_w_scale = layer.weight_scale.max() if not (layer.weight_scale == layer.weight_scale[0]).all(): max_w_scale, weight = requantize_with_max_scale( layer.weight, layer.weight_scale, layer.logical_widths ) layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: return self.fp8_linear.apply_weights(layer, x, bias) class ModelOptFp8PcPtLinearMethod(LinearMethodBase): """Linear method for ModelOpt FP8_PER_CHANNEL_PER_TOKEN checkpoints. Expected checkpoint structure (per Linear): - weight: fp8-e4m3fn, shape [out, in] - weight_scale: fp32, shape [out] (per-output-channel) - no input_scale (activations are dynamically quantized per-token) """ def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config self.fp8_linear = init_fp8_linear_kernel( activation_quant_key=kFp8DynamicTokenSym, weight_quant_key=kFp8StaticTokenSym, out_dtype=torch.get_default_dtype(), module_name=self.__class__.__name__, ) def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): del input_size, output_size if not self.quant_config.is_checkpoint_fp8_serialized: raise ValueError( "FP8_PER_CHANNEL_PER_TOKEN currently only supports " "FP8-serialized checkpoints." ) output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition weight = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, dtype=torch.float8_e4m3fn, ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight", weight) weight_scale = ChannelQuantScaleParameter( data=torch.empty(output_size_per_partition, dtype=torch.float32), output_dim=0, weight_loader=weight_loader, ) weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.weight = Parameter(layer.weight.t(), requires_grad=False) layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: return self.fp8_linear.apply_weights(layer, x, bias) class ModelOptFp8PbWoLinearMethod(LinearMethodBase): """Linear method for ModelOpt FP8_PB_WO checkpoints. ModelOpt exports `weight_scale` as a 4D tensor: [out_blk, 1, in_blk, 1] where block size is typically 128 for both dims. vLLM executes it as FP8 GEMM with *dynamic per-token* activation quant. """ _WEIGHT_BLOCK_SIZE: tuple[int, int] = (128, 128) def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config block_n, block_k = self._WEIGHT_BLOCK_SIZE self.weight_block_size = list(self._WEIGHT_BLOCK_SIZE) self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp( weight_group_shape=GroupShape(block_n, block_k), act_quant_group_shape=GroupShape(1, block_k), cutlass_block_fp8_supported=cutlass_block_fp8_supported(), use_aiter_and_is_supported=False, ) def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): del input_size, output_size if not self.quant_config.is_checkpoint_fp8_serialized: raise ValueError( "FP8_PB_WO currently only supports FP8-serialized checkpoints." ) output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition # Expose block size so the v2 weight loaders can translate offsets from # element-space -> block-space for BlockQuantScaleParameter. layer.weight_block_size = self.weight_block_size weight = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, dtype=torch.float8_e4m3fn, ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight", weight) block_n, block_k = self._WEIGHT_BLOCK_SIZE if output_size_per_partition % block_n != 0: raise ValueError( "ModelOpt FP8_PB_WO requires out_features divisible by " f"{block_n}, got {output_size_per_partition}." ) if input_size_per_partition % block_k != 0: raise ValueError( "ModelOpt FP8_PB_WO requires in_features divisible by " f"{block_k}, got {input_size_per_partition}." ) out_blks = output_size_per_partition // block_n in_blks = input_size_per_partition // block_k # Match ModelOpt's exported shape so weight loading works without a # custom loader: [out_blk, 1, in_blk, 1] weight_scale = BlockQuantScaleParameter( data=torch.empty((out_blks, 1, in_blks, 1), dtype=torch.float32), input_dim=2, output_dim=0, weight_loader=weight_loader, ) weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Keep weight in [out, in] layout for W8A8BlockFp8LinearOp. layer.weight = Parameter(layer.weight.data, requires_grad=False) scale = layer.weight_scale if scale.dim() == 4: # [out_blk, 1, in_blk, 1] -> [out_blk, in_blk] scale = scale.squeeze(1).squeeze(-1) elif scale.dim() != 2: raise ValueError( "Unexpected ModelOpt FP8_PB_WO weight_scale shape: " f"{tuple(scale.shape)}." ) layer.weight_scale = Parameter(scale.contiguous(), requires_grad=False) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: return self.w8a8_block_fp8_linear.apply( input=x, weight=layer.weight, weight_scale=layer.weight_scale, input_scale=None, bias=bias, ) class ModelOptFp8MoEMethod(FusedMoEMethodBase): """MoE method for ModelOpt FP8. Supports loading FP8 checkpoints with static weight scale and activation scale. Args: quant_config: The ModelOpt quantization config. """ def __init__( self, quant_config: ModelOptFp8Config, moe_config: FusedMoEConfig, ) -> None: super().__init__(moe_config) self.quant_config = quant_config assert self.quant_config.is_checkpoint_fp8_serialized # Select Fp8 MoE backend self.fp8_backend, self.experts_cls = select_fp8_moe_backend( config=self.moe, weight_key=kFp8StaticTensorSym, activation_key=kFp8StaticTensorSym, ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalizeModular | None: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalizeModular, layer: torch.nn.Module, ) -> mk.FusedMoEExpertsModular: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." ) def create_weights( self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): layer.orig_dtype = params_dtype layer.num_experts = num_experts # Use FP8 dtype if checkpoint is serialized weight_dtype = ( torch.float8_e4m3fn if self.quant_config.is_checkpoint_fp8_serialized else params_dtype ) weight_loader = extra_weight_attrs.get("weight_loader") w13_num_shards = 2 if self.moe.is_act_and_mul else 1 w13_weight = ModelWeightParameter( data=torch.empty( num_experts, w13_num_shards * intermediate_size_per_partition, hidden_size, dtype=weight_dtype, ), input_dim=2, output_dim=1, weight_loader=weight_loader, ) layer.register_parameter("w13_weight", w13_weight) w2_weight = ModelWeightParameter( data=torch.empty( num_experts, hidden_size, intermediate_size_per_partition, dtype=weight_dtype, ), input_dim=2, output_dim=1, weight_loader=weight_loader, ) layer.register_parameter("w2_weight", w2_weight) # WEIGHT SCALES - Per-tensor scaling for ModelOpts # For gated MoE, allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. # For non-gated MoE, allocate 1 scale for w13. w13_weight_scale = PerTensorScaleParameter( data=torch.full( (num_experts, w13_num_shards), 1.0, dtype=torch.float32, ), weight_loader=weight_loader, ) w2_weight_scale = PerTensorScaleParameter( data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) # INPUT SCALES - Per-tensor scaling for ModelOpt w13_input_scale = PerTensorScaleParameter( data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) w2_input_scale = PerTensorScaleParameter( data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_input_scale", w13_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale) def _setup_kernel( self, layer: FusedMoE, w13: torch.Tensor, w2: torch.Tensor, w13_scale: torch.Tensor, w2_scale: torch.Tensor, w13_input_scale: torch.Tensor, w2_input_scale: torch.Tensor, ): w13, w2, w13_scale, w2_scale = convert_to_fp8_moe_kernel_format( fp8_backend=self.fp8_backend, layer=layer, w13=w13, w2=w2, w13_scale=w13_scale, w2_scale=w2_scale, w13_input_scale=w13_input_scale, w2_input_scale=w2_input_scale, ) # Replace parameters with updated versions. Note that this helper # function ensures the replacement is compatible with RL weight reloads. replace_parameter(layer, "w13_weight", w13) replace_parameter(layer, "w2_weight", w2) replace_parameter(layer, "w13_weight_scale", w13_scale) replace_parameter(layer, "w2_weight_scale", w2_scale) # Setup modular kernel. self.moe_quant_config = self.get_fused_moe_quant_config(layer) assert self.experts_cls is not None self.moe_kernel = make_fp8_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, fp8_backend=self.fp8_backend, experts_cls=self.experts_cls, routing_tables=layer._maybe_init_expert_routing_tables(), shared_experts=layer.shared_experts, ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: w13 = layer.w13_weight w2 = layer.w2_weight w13_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale w13_input_scale = layer.w13_input_scale w2_input_scale = layer.w2_input_scale # Per tensor kernels require single activation scale. Use the max. w13_input_scale, w2_input_scale = process_fp8_input_tensor_strategy_moe( w13_input_scale, w2_input_scale ) replace_parameter(layer, "w13_input_scale", w13_input_scale) replace_parameter(layer, "w2_input_scale", w2_input_scale) # Per tensor kernels require single weight scale for w13 per expert, but # on disk there is a scale for w1 and w3. Use the max to requantize. shard_size = layer.intermediate_size_per_partition w13, w13_scale = process_fp8_weight_tensor_strategy_moe( w13, w13_scale, shard_size, num_experts=layer.w13_weight.shape[0], is_act_and_mul=self.moe.is_act_and_mul, ) # Shuffle weights to runtime format and setup kernel. self._setup_kernel( layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale ) def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: w1_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale a1_scale = layer.w13_input_scale a2_scale = layer.w2_input_scale return make_fp8_moe_quant_config( fp8_backend=self.fp8_backend, w1_scale=w1_scale, w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, ) def apply_monolithic( self, layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic assert self.moe_kernel is not None return self.moe_kernel.apply_monolithic( x, layer.w13_weight, layer.w2_weight, router_logits, activation=layer.activation, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, e_score_correction_bias=layer.e_score_correction_bias, routed_scaling_factor=layer.routed_scaling_factor, ) def apply( self, layer: FusedMoE, x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic assert self.moe_kernel is not None return self.moe_kernel.apply( x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids, activation=layer.activation, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, shared_experts_input=shared_experts_input, ) ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod class ModelOptNvFp4Config(ModelOptQuantConfigBase): """Config class for ModelOpt FP4.""" def __init__( self, is_checkpoint_nvfp4_serialized: bool, kv_cache_quant_algo: str | None, exclude_modules: list[str], group_size: int = 16, ) -> None: super().__init__(exclude_modules) self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized if is_checkpoint_nvfp4_serialized: logger.warning( "Detected ModelOpt NVFP4 checkpoint. Please note that" " the format is experimental and could change in future." ) self.group_size = group_size self.kv_cache_quant_algo = kv_cache_quant_algo def get_name(self) -> QuantizationMethods: return "modelopt_fp4" def get_supported_act_dtypes(self) -> list[torch.dtype]: return [torch.bfloat16, torch.half, torch.float8_e4m3fn] @classmethod def get_min_capability(cls) -> int: return 75 @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant ) -> QuantizationMethods | None: algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) if algo is not None and ("NVFP4" in algo or "FP4" in algo): return "modelopt_fp4" return None @classmethod def _from_config( cls, *, quant_method: str, kv_cache_quant_method: str | None, exclude_modules: list[str], original_config: dict[str, Any], group_size: int | None, **kwargs: Any, ) -> "ModelOptNvFp4Config": is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method if group_size is None: group_size = 16 # Default value # For FP4, these fields are required if is_checkpoint_nvfp4_serialized and "quantization" in original_config: # Check if required fields are present in the quantization config quant_config = original_config["quantization"] required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"] missing_fields = [ field for field in required_fields if field not in quant_config ] if missing_fields: raise ValueError( f"NVFP4 quantization requires the following fields in " f"hf_quant_config.json: {missing_fields}" ) return cls( is_checkpoint_nvfp4_serialized, kv_cache_quant_method, exclude_modules, group_size, ) class ModelOptNvFp4LinearMethod(LinearMethodBase): """Linear method for Model Optimizer NVFP4. Supports loading NVFP4 checkpoints with the following structure: input_scale: torch.float32, scalar , weight: NVFP4(represented as byte) Shape: [1, X, y/2] weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale, weight_scale_2: torch.float32, scalar, Args: quant_config: The ModelOpt quantization config. """ def __init__(self, quant_config: ModelOptNvFp4Config) -> None: self.quant_config = quant_config self.marlin_input_dtype = None self.backend = select_nvfp4_linear_backend() def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): del input_size, output_size if not self.quant_config.is_checkpoint_nvfp4_serialized: raise ValueError( "NVFP4 quantization was selected, " " dynamic quantization is not supported." ) output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition if input_size_per_partition % 16 != 0: raise ValueError( "Unsupported model when in features size is not multiple of 16" ) # The nvfp4 weight is still represented as weight_dtype = ( torch.float8_e4m3fn if self.quant_config.is_checkpoint_nvfp4_serialized else params_dtype ) # Weight weight = ModelWeightParameter( data=torch.empty( # 2 fp4 items are packed in the input dimension layer.output_size_per_partition, layer.input_size_per_partition // 2, dtype=torch.uint8, ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight", weight) # Input Global Scale input_global_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("input_scale", input_global_scale) # Weight Global Scale weight_global_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("weight_scale_2", weight_global_scale) # Per Block Weight Scale weight_scale = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition // self.quant_config.group_size, dtype=weight_dtype, ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Rename ModelOpt checkpoint names to standardized names input_global_scale = layer.input_scale.max().to(torch.float32) layer.input_global_scale = Parameter(input_global_scale, requires_grad=False) del layer.input_scale weight_global_scale = layer.weight_scale_2.max().to(torch.float32) layer.weight_global_scale = Parameter(weight_global_scale, requires_grad=False) del layer.weight_scale_2 # Pre-compute alpha and inverse for runtime quantization layer.alpha = Parameter( layer.input_global_scale * layer.weight_global_scale, requires_grad=False ) layer.input_global_scale_inv = Parameter( (1.0 / layer.input_global_scale).to(torch.float32), requires_grad=False ) # Convert layer to NVFP4 linear kernel format convert_to_nvfp4_linear_kernel_format(self.backend, layer) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: return apply_nvfp4_linear( backend=self.backend, layer=layer, x=x, bias=bias, ) class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): """ MoE Method for FP4 Quantization. Args: quant_config: NVFP4 Quant Config """ def __init__( self, quant_config: ModelOptNvFp4Config, moe_config: FusedMoEConfig, ) -> None: super().__init__(moe_config) self.quant_config = quant_config # Select experts implementation. self.nvfp4_backend, self.experts_cls = select_nvfp4_moe_backend( config=self.moe, weight_key=kNvfp4Static, activation_key=kNvfp4Dynamic, ) self.use_global_sf = is_global_sf_supported_for_nvfp4_backend( self.nvfp4_backend ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalizeModular | None: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel initialization " "logic. This function should not be called." ) def uses_weight_scale_2_pattern(self) -> bool: """ FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales. """ return True def create_weights( self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): assert self.quant_config.is_checkpoint_nvfp4_serialized layer.num_experts = num_experts layer.params_dtype = params_dtype layer.quant_config = self.quant_config weight_dtype = torch.uint8 weight_scale_dtype = torch.float8_e4m3fn weight_loader = extra_weight_attrs.get("weight_loader") global_num_experts = extra_weight_attrs.get("global_num_experts") w13_num_shards = 2 if self.moe.is_act_and_mul else 1 # GEMM 1 w13_weight = ModelWeightParameter( data=torch.empty( num_experts, w13_num_shards * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // 2, dtype=weight_dtype, ), input_dim=1, output_dim=2, weight_loader=weight_loader, ) layer.register_parameter("w13_weight", w13_weight) # GEMM 2 w2_weight = ModelWeightParameter( data=torch.empty( num_experts, hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // 2, dtype=weight_dtype, ), input_dim=1, output_dim=2, weight_loader=weight_loader, ) layer.register_parameter("w2_weight", w2_weight) w13_weight_scale = ModelWeightParameter( data=torch.empty( num_experts, w13_num_shards * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // self.quant_config.group_size, dtype=weight_scale_dtype, ), input_dim=1, output_dim=2, weight_loader=weight_loader, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) w2_weight_scale = ModelWeightParameter( data=torch.empty( num_experts, hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // self.quant_config.group_size, dtype=weight_scale_dtype, ), input_dim=1, output_dim=2, weight_loader=weight_loader, ) layer.register_parameter("w2_weight_scale", w2_weight_scale) extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} ) w13_weight_scale_2 = PerTensorScaleParameter( data=torch.empty(num_experts, w13_num_shards, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) w2_weight_scale_2 = PerTensorScaleParameter( data=torch.empty(num_experts, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) global_sf_num_experts = ( global_num_experts if self.use_global_sf else num_experts ) w13_input_scale = PerTensorScaleParameter( data=torch.empty( global_sf_num_experts, w13_num_shards, dtype=torch.float32, ), weight_loader=weight_loader, ) layer.register_parameter("w13_input_scale", w13_input_scale) w2_input_scale = PerTensorScaleParameter( data=torch.empty(global_sf_num_experts, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w2_input_scale", w2_input_scale) def process_weights_after_loading(self, layer: FusedMoE) -> None: """ Convert NVFP4 MoE weights into kernel format and setup the kernel. """ # Use a single gscale for w13. if self.moe.is_act_and_mul and not torch.allclose( layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1] ): logger.warning_once( "w1_weight_scale_2 must match w3_weight_scale_2. " "Accuracy may be affected." ) w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous() ( w13, w13_scale, w13_scale_2, a13_scale, w2, w2_scale, w2_scale_2, a2_scale, ) = convert_to_nvfp4_moe_kernel_format( nvfp4_backend=self.nvfp4_backend, layer=layer, w13=layer.w13_weight, w13_scale=layer.w13_weight_scale, w13_scale_2=w13_weight_scale_2, a13_scale=layer.w13_input_scale, w2=layer.w2_weight, w2_scale=layer.w2_weight_scale, w2_scale_2=layer.w2_weight_scale_2, a2_scale=layer.w2_input_scale, is_act_and_mul=self.moe.is_act_and_mul, ) replace_parameter(layer, "w13_weight", w13) replace_parameter(layer, "w13_weight_scale", w13_scale) replace_parameter(layer, "w13_weight_scale_2", w13_scale_2) replace_parameter(layer, "w13_input_scale", a13_scale) replace_parameter(layer, "w2_weight", w2) replace_parameter(layer, "w2_weight_scale", w2_scale) replace_parameter(layer, "w2_weight_scale_2", w2_scale_2) replace_parameter(layer, "w2_input_scale", a2_scale) # Setup modular kernel. self.moe_quant_config = self.get_fused_moe_quant_config(layer) assert self.experts_cls is not None self.moe_kernel = make_nvfp4_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, experts_cls=self.experts_cls, shared_experts=layer.shared_experts, routing_tables=layer._maybe_init_expert_routing_tables(), ) def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantConfig: return make_nvfp4_moe_quant_config( backend=self.nvfp4_backend, w13_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, w13_scale_2=layer.w13_weight_scale_2, w2_scale_2=layer.w2_weight_scale_2, a13_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, ) @property def supports_eplb(self) -> bool: return True def apply_monolithic( self, layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic assert self.moe_kernel is not None return self.moe_kernel.apply_monolithic( x, layer.w13_weight, layer.w2_weight, router_logits, activation=layer.activation, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, e_score_correction_bias=layer.e_score_correction_bias, routed_scaling_factor=layer.routed_scaling_factor, ) def apply( self, layer: FusedMoE, x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic assert self.moe_kernel is not None return self.moe_kernel.apply( x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids, activation=layer.activation, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, shared_experts_input=shared_experts_input, ) ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod class ModelOptMxFp8Config(ModelOptQuantConfigBase): """Config class for ModelOpt MXFP8.""" def __init__( self, is_checkpoint_mxfp8_serialized: bool, kv_cache_quant_algo: str | None, exclude_modules: list[str], ) -> None: super().__init__(exclude_modules) self.is_checkpoint_mxfp8_serialized = is_checkpoint_mxfp8_serialized if not is_checkpoint_mxfp8_serialized: raise ValueError( "MXFP8 quantization requires a serialized checkpoint. " "Dynamic quantization is not supported." ) logger.warning( "Detected ModelOpt MXFP8 checkpoint. Please note that " "the format is experimental and could change in future." ) self.kv_cache_quant_algo = kv_cache_quant_algo def get_name(self) -> QuantizationMethods: return "modelopt_mxfp8" def get_supported_act_dtypes(self) -> list[torch.dtype]: return [torch.bfloat16] @classmethod def get_min_capability(cls) -> int: # MXFP8 hardware acceleration requires Blackwell (SM100) or newer return 100 def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> "QuantizeMethodBase | None": # MXFP8 does not yet support MoE models if isinstance(layer, FusedMoE): raise NotImplementedError( "MXFP8 quantization does not yet support MoE models. " "Please use FP8 or NVFP4 quantization for MoE models." ) return super().get_quant_method(layer, prefix) @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant ) -> QuantizationMethods | None: algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) if algo is not None and "MXFP8" in algo: return "modelopt_mxfp8" return None @classmethod def _from_config( cls, *, quant_method: str, kv_cache_quant_method: str | None, exclude_modules: list[str], original_config: dict[str, Any], **kwargs: Any, ) -> "ModelOptMxFp8Config": is_checkpoint_mxfp8_serialized = "MXFP8" in quant_method.upper() # For MXFP8, validate required fields in the config if is_checkpoint_mxfp8_serialized and "quantization" in original_config: quant_config = original_config["quantization"] required_fields = ["kv_cache_quant_algo", "exclude_modules"] missing_fields = [ field for field in required_fields if field not in quant_config ] if missing_fields: raise ValueError( f"MXFP8 quantization requires the following fields in " f"hf_quant_config.json: {missing_fields}" ) return cls( is_checkpoint_mxfp8_serialized, kv_cache_quant_method, exclude_modules, ) class ModelOptMxFp8LinearMethod(LinearMethodBase): """Linear method for ModelOpt MXFP8 quantization.""" def __init__(self, quant_config: ModelOptMxFp8Config) -> None: self.quant_config = quant_config if not self.quant_config.is_checkpoint_mxfp8_serialized: raise ValueError( "MXFP8 currently only supports serialized checkpoints. " "Dynamic quantization is not supported." ) self.backend: Mxfp8LinearBackend = Mxfp8LinearBackend.FLASHINFER_CUTLASS self.mxfp8_linear_op = Mxfp8LinearOp(backend=self.backend) logger.info_once("Using %s backend for MXFP8 GEMM", self.backend.value) def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): del input_size, output_size if not self.quant_config.is_checkpoint_mxfp8_serialized: raise ValueError( "MXFP8 quantization was selected, but checkpoint is not " "MXFP8 serialized. Dynamic quantization is not supported." ) output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition if input_size_per_partition % MXFP8_BLOCK_SIZE != 0: raise ValueError( f"MXFP8 requires input dimension to be divisible by " f"{MXFP8_BLOCK_SIZE}, got {input_size_per_partition}" ) # Weight tensor: FP8 E4M3 format weight = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, dtype=MXFP8_VALUE_DTYPE, ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight", weight) # Weight scale tensor (E8M0 encoded as uint8), one scale per block of 32 along K weight_scale = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition // MXFP8_BLOCK_SIZE, dtype=MXFP8_SCALE_DTYPE, ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight_scale", weight_scale) def _process_weights_after_loading_scale_2d(self, layer: torch.nn.Module) -> None: """Not swizzled - MXFP8 GEMM emulation""" weight = layer.weight.data # [N, K] N, K = weight.shape scale_k = K // MXFP8_BLOCK_SIZE # Slice weight_scale to match weight dimensions (handles padding) weight_scale = layer.weight_scale.data[:N, :scale_k].contiguous() layer.weight = Parameter(weight.contiguous(), requires_grad=False) layer.weight_scale = Parameter(weight_scale, requires_grad=False) def _process_weights_after_loading_scale_1d(self, layer: torch.nn.Module) -> None: """Swizzled - MXFP8 GEMM Flashinfer CUTLASS""" weight = layer.weight.data # [N, K] N, K = weight.shape # 2D weight scale weight_scale = layer.weight_scale.data # Swizzle the weight scales scale_k = K // MXFP8_BLOCK_SIZE weight_scale_2d = weight_scale[:N, :scale_k].contiguous() weight_scale_swizzled = swizzle_mxfp8_scale(weight_scale_2d, M=N, K=K) layer.weight = Parameter(weight.contiguous(), requires_grad=False) layer.weight_scale = Parameter( weight_scale_swizzled.contiguous(), requires_grad=False ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Validate weight tensor if layer.weight.ndim != 2: raise ValueError( f"MXFP8 weight must be 2D tensor [N, K], got {layer.weight.ndim}D " f"with shape {tuple(layer.weight.shape)}" ) if layer.weight.dtype != MXFP8_VALUE_DTYPE: raise ValueError( f"MXFP8 weight must be {MXFP8_VALUE_DTYPE} (FP8 E4M3), " f"got {layer.weight.dtype}. The checkpoint may not be properly " f"quantized with MXFP8." ) # Validate weight scale tensor (should be 2D, not swizzled) assert layer.weight_scale.ndim == 2, ( f"MXFP8 weight scale must be 2D, got {layer.weight_scale.ndim}D" ) assert layer.weight_scale.dtype == MXFP8_SCALE_DTYPE, ( f"MXFP8 weight scale must be {MXFP8_SCALE_DTYPE}," f" got {layer.weight_scale.dtype}" ) if self.backend == Mxfp8LinearBackend.EMULATION: # Swizzled layout is not used self._process_weights_after_loading_scale_2d(layer) return assert self.backend == Mxfp8LinearBackend.FLASHINFER_CUTLASS # Swizzled layout is required for Flashinfer CUTLASS self._process_weights_after_loading_scale_1d(layer) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: if layer.weight.dtype != MXFP8_VALUE_DTYPE: raise ValueError( f"Weight dtype {layer.weight.dtype} != expected {MXFP8_VALUE_DTYPE}" ) if layer.weight_scale.dtype != MXFP8_SCALE_DTYPE: raise ValueError( f"Weight scale dtype {layer.weight_scale.dtype} != " f"expected {MXFP8_SCALE_DTYPE}" ) return self.mxfp8_linear_op.apply( input=x, weight=layer.weight, weight_scale=layer.weight_scale, out_dtype=x.dtype, bias=bias, ) # Register the method classes for ModelOptMxFp8Config ModelOptMxFp8Config.LinearMethodCls = ModelOptMxFp8LinearMethod ModelOptMxFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod class ModelOptMixedPrecisionConfig(ModelOptQuantConfigBase): """Config class for ModelOpt MIXED_PRECISION. Supports checkpoints where different layers use different quantization algorithms (e.g., FP8 for dense layers and NVFP4 for MoE experts). The per-layer algorithm is specified in the ``quantized_layers`` dict inside ``config.json``'s ``quantization_config`` (preferred) or the legacy ``hf_quant_config.json``. """ def __init__( self, kv_cache_quant_method: str | None, exclude_modules: list[str], quantized_layers: dict[str, dict[str, Any]], fp8_config: ModelOptFp8Config, nvfp4_config: ModelOptNvFp4Config, ) -> None: super().__init__(exclude_modules) self.kv_cache_quant_method = kv_cache_quant_method self.quantized_layers = quantized_layers self.fp8_config = fp8_config self.nvfp4_config = nvfp4_config def get_name(self) -> QuantizationMethods: return "modelopt_mixed" def get_supported_act_dtypes(self) -> list[torch.dtype]: return [torch.bfloat16, torch.half] @classmethod def get_min_capability(cls) -> int: return 89 @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant ) -> QuantizationMethods | None: algo = cls._extract_modelopt_quant_algo(hf_quant_cfg) if algo is not None and algo == "MIXED_PRECISION": return "modelopt_mixed" return None @classmethod def _from_config( cls, *, quant_method: str, kv_cache_quant_method: str | None, exclude_modules: list[str], original_config: dict[str, Any], group_size: int | None, **kwargs: Any, ) -> "ModelOptMixedPrecisionConfig": if "quantization" in original_config: quantized_layers = original_config["quantization"].get( "quantized_layers", {} ) else: quantized_layers = original_config.get("quantized_layers", {}) if not quantized_layers: raise ValueError( "MIXED_PRECISION quant_algo requires a non-empty " "'quantized_layers' mapping in the quantization config." ) # Determine group_size from the first NVFP4 entry if not provided. if group_size is None: for layer_info in quantized_layers.values(): if layer_info.get("quant_algo", "").upper() == "NVFP4": group_size = layer_info.get("group_size", 16) break if group_size is None: group_size = 16 fp8_config = ModelOptFp8Config( quant_method="FP8", is_checkpoint_fp8_serialized=True, kv_cache_quant_method=kv_cache_quant_method, exclude_modules=[], ) nvfp4_config = ModelOptNvFp4Config( is_checkpoint_nvfp4_serialized=True, kv_cache_quant_algo=kv_cache_quant_method, exclude_modules=[], group_size=group_size, ) return cls( kv_cache_quant_method=kv_cache_quant_method, exclude_modules=exclude_modules, quantized_layers=quantized_layers, fp8_config=fp8_config, nvfp4_config=nvfp4_config, ) def _resolve_quant_algo(self, prefix: str) -> str | None: """Look up the quant_algo for a vLLM-side layer prefix. Tries three strategies in order: 1. Direct lookup in ``quantized_layers``. 2. Packed/fused-layer lookup (unfuse via ``packed_modules_mapping``). 3. Prefix-based lookup for FusedMoE (any child key starts with ``prefix + "."``). Returns the upper-cased quant_algo string, or *None* if the prefix is not found. """ # 1. Direct lookup if prefix in self.quantized_layers: return self.quantized_layers[prefix]["quant_algo"].upper() # 2. Packed / fused layer lookup proj_name = prefix.rsplit(".", 1)[-1] if self.packed_modules_mapping and proj_name in self.packed_modules_mapping: algos: set[str] = set() base = prefix.rsplit(".", 1)[0] for shard_name in self.packed_modules_mapping[proj_name]: shard_prefix = f"{base}.{shard_name}" if shard_prefix in self.quantized_layers: algos.add(self.quantized_layers[shard_prefix]["quant_algo"].upper()) if len(algos) == 1: return algos.pop() if len(algos) > 1: raise ValueError( f"Mixed quant_algo within fused layer {prefix}: " f"{algos}. All shards must use the same quantization." ) # 3. Prefix-based lookup (for FusedMoE / parent modules) prefix_dot = prefix + "." for key, info in self.quantized_layers.items(): if key.startswith(prefix_dot): return info["quant_algo"].upper() return None def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> "QuantizeMethodBase | None": """Return quantize-method based on layer.""" # KV-cache quantization if isinstance(layer, Attention): if self.kv_cache_quant_method: return ModelOptFp8KVCacheMethod(self) return None # Excluded layers if self.is_layer_excluded(prefix): if isinstance(layer, LinearBase): return UnquantizedLinearMethod() return None quant_algo = self._resolve_quant_algo(prefix) if isinstance(layer, LinearBase): if quant_algo == "FP8": return ModelOptFp8LinearMethod(self.fp8_config) if quant_algo == "NVFP4": return ModelOptNvFp4LinearMethod(self.nvfp4_config) # Layer not in quantized_layers — leave unquantized return UnquantizedLinearMethod() if isinstance(layer, FusedMoE): if quant_algo == "FP8": return ModelOptFp8MoEMethod( quant_config=self.fp8_config, moe_config=layer.moe_config, ) if quant_algo == "NVFP4": return ModelOptNvFp4FusedMoE( quant_config=self.nvfp4_config, moe_config=layer.moe_config, ) return None return None def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): super().apply_vllm_mapper(hf_to_vllm_mapper) if self.quantized_layers: self.quantized_layers = hf_to_vllm_mapper.apply_dict(self.quantized_layers)