# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from fnmatch import fnmatch from typing import TYPE_CHECKING, Any, Optional import torch from torch.nn import Module from torch.nn.parameter import Parameter import vllm.envs as envs import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant from vllm.attention.layer import Attention from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, nvfp4_moe_quant_config, ) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, ) from vllm.model_executor.layers.linear import ( LinearBase, LinearMethodBase, UnquantizedLinearMethod, ) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( build_flashinfer_fp4_cutlass_moe_prepare_finalize, flashinfer_trtllm_fp4_moe, flashinfer_trtllm_fp4_routed_moe, prepare_static_weights_for_trtllm_fp4_moe, reorder_w1w3_to_w3w1, select_nvfp4_gemm_impl, ) from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, build_flashinfer_fp8_cutlass_moe_prepare_finalize, flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, is_flashinfer_supporting_global_sf, register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, select_cutlass_fp8_gemm_impl, swap_w13_to_w31, ) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( get_marlin_input_dtype, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( apply_fp4_marlin_linear, is_fp4_marlin_supported, prepare_fp4_layer_for_marlin, prepare_moe_fp4_layer_for_marlin, ) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape, cutlass_fp4_supported, is_layer_skipped, swizzle_blockscale, ) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( Fp8LinearOp, requantize_with_max_scale, ) from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter from vllm.scalar_type import scalar_types from vllm.utils.flashinfer import ( flashinfer_scaled_fp4_mm, has_flashinfer, has_flashinfer_moe, ) from vllm.utils.math_utils import round_up if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper logger = init_logger(__name__) QUANT_ALGOS = ["FP8", "NVFP4"] KV_CACHE_QUANT_ALGOS = ["FP8"] class ModelOptFp8KVCacheMethod(BaseKVCacheMethod): """ Supports loading kv-cache scaling factors from FP8 checkpoints. """ def __init__(self, quant_config: "ModelOptQuantConfigBase"): super().__init__(quant_config) class ModelOptQuantConfigBase(QuantizationConfig): LinearMethodCls: type = LinearMethodBase FusedMoEMethodCls: type = FusedMoEMethodBase KVCacheMethodCls: type = BaseKVCacheMethod def __init__( self, exclude_modules: list[str], ): super().__init__() self.exclude_modules: list[str] = exclude_modules def is_layer_excluded(self, prefix: str) -> bool: """ Check if a layer should be excluded from quantization. 
Handles both exact matching (for fused layers) and ModelOpt wildcard matching. The ModelOpt exclude_modules list is a list of wildcards. """ if len(self.exclude_modules) == 0: return False # First check exact matching with fused layer support if is_layer_skipped(prefix, self.exclude_modules, self.packed_modules_mapping): return True # TODO: This special hard-coded logic is not needed for quantized checkpoints # generated by ModelOpt >= 0.39.0, where they are handled naturally by the # exclude_modules config. But we need to keep it for loading quantized # checkpoints generated by older versions. Then check substring matching # for patterns not caught by exact match for exclude_module in self.exclude_modules: # Skip exact matches already handled above if exclude_module != prefix and ( exclude_module in prefix or ( prefix.startswith("language_model.") and exclude_module in prefix.removeprefix("language_model.") ) ): return True # modelopt exclude modules are not simple strings, they are wildcards for wildcard_pattern in self.exclude_modules: if fnmatch(prefix, wildcard_pattern): return True return False def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional["QuantizeMethodBase"]: # handle kv-cache first so we can focus only on weight quantization thereafter if isinstance(layer, Attention): return self.KVCacheMethodCls(self) # handle exclusion if self.is_layer_excluded(prefix): if isinstance(layer, LinearBase): return UnquantizedLinearMethod() return None # TODO: This special hard-coded logic is not needed for quantized checkpoints # generated by ModelOpt >= 0.39.0, where they are handled naturally by the # exclude_modules config. But we need to keep it for loading quantized # checkpoints generated by older versions. Then check substring matching # for patterns not caught by exact match if "vision_tower" in prefix or "vision_model" in prefix: return UnquantizedLinearMethod() # now, the layer is quantized, handle it here if isinstance(layer, LinearBase): quant_method = self.LinearMethodCls(self) if getattr(quant_method, "backend", "") == "marlin": quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) return quant_method elif isinstance(layer, FusedMoE): quant_method = self.FusedMoEMethodCls(quant_config=self, layer=layer) if getattr(quant_method, "backend", "") == "marlin": quant_method.marlin_input_dtype = get_marlin_input_dtype(prefix) return quant_method return None def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): if len(self.exclude_modules) > 0: # This is a workaround for the weights remapping issue: # https://github.com/vllm-project/vllm/issues/28072 # Right now, the Nvidia ModelOpt library uses just one wildcard pattern: # module_path* # It gets applied if the whole tree of modules rooted at module_path # is not quantized.
Here we replace such pattern by 2 patterns that are # collectively equivalent to the original pattern: # module_path # module_path.* new_exclude_modules = [] for exclude in self.exclude_modules: if len(exclude) >= 2 and exclude[-1] == "*" and exclude[-2] != ".": new_exclude_modules.append(exclude[:-1]) new_exclude_modules.append(exclude[:-1] + ".*") else: new_exclude_modules.append(exclude) self.exclude_modules = hf_to_vllm_mapper.apply_list(new_exclude_modules) @staticmethod def get_config_filenames() -> list[str]: return ["hf_quant_config.json"] @classmethod def _from_config( cls, *, quant_method: str, kv_cache_quant_method: str | None, exclude_modules: list[str], original_config: dict[str, Any], group_size: int | None, ) -> "ModelOptQuantConfigBase": raise NotImplementedError("Please implement this function in sub classes") @classmethod def from_config(cls, config: dict[str, Any]) -> "ModelOptQuantConfigBase": # Handle both ModelOpt format and compressed-tensors style format if "quantization" in config: # Traditional ModelOpt format: # {"quantization": {"quant_algo": "..."}} quant_config = cls.get_from_keys(config, ["quantization"]) if not isinstance(quant_config, dict): raise ValueError("Expected 'quantization' to be a dictionary in config") quant_method = quant_config.get("quant_algo") # Handle kv_cache_quant_algo with proper type validation kv_cache_quant_method = quant_config.get("kv_cache_quant_algo") # Handle group_size with proper type validation group_size_raw = quant_config.get("group_size") # "exclude_modules" is the key in the legacy hf_quant_config.json exclude_modules = quant_config.get("exclude_modules", []) else: # Compressed-tensors style format: # {"quant_algo": "...", "quant_method": "modelopt"} quant_method = config.get("quant_algo") kv_cache_quant_method = config.get("kv_cache_quant_algo") # "ignore" is the key in config.json exclude_modules = config.get("ignore", []) group_size_raw = config.get("group_size") if not quant_method: raise ValueError("Missing 'quant_algo' in quantization config") if kv_cache_quant_method is None: # No KV cache quantization, keep this branch just to have this comment pass elif not isinstance(kv_cache_quant_method, str): raise ValueError( f"kv_cache_quant_algo must be a string, got " f"{type(kv_cache_quant_method)}" ) if not isinstance(exclude_modules, list): raise ValueError( f"exclude_modules must be a list, got {type(exclude_modules)}" ) if group_size_raw is None: group_size = None elif isinstance(group_size_raw, int): group_size = group_size_raw else: try: group_size = int(group_size_raw) except (ValueError, TypeError): raise ValueError( f"group_size must be an integer, got {type(group_size_raw)}" ) from None if quant_method not in QUANT_ALGOS: raise ValueError( f"ModelOpt currently only supports: {QUANT_ALGOS} " "quantizations in vLLM. Please check the " "`hf_quant_config.json` file for your model's " "quant configuration." 
) return cls._from_config( quant_method=quant_method, kv_cache_quant_method=kv_cache_quant_method, exclude_modules=exclude_modules, group_size=group_size, original_config=config, ) class ModelOptFp8Config(ModelOptQuantConfigBase): """Config class for ModelOpt FP8.""" def __init__( self, is_checkpoint_fp8_serialized: bool, kv_cache_quant_method: str | None, exclude_modules: list[str], ) -> None: super().__init__(exclude_modules) self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized self.kv_cache_quant_method = kv_cache_quant_method if is_checkpoint_fp8_serialized: logger.warning( "Detected ModelOpt fp8 checkpoint. Please note that" " the format is experimental and could change." ) def get_name(self) -> QuantizationMethods: return "modelopt" def get_supported_act_dtypes(self) -> list[torch.dtype]: return [torch.bfloat16, torch.half] @classmethod def get_min_capability(cls) -> int: return 89 @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant ) -> QuantizationMethods | None: """Detect if this ModelOpt config should be used based on quantization config.""" if hf_quant_cfg is None: return None # Use the community standard 'quant_method' quant_method = hf_quant_cfg.get("quant_method", "").lower() # Only proceed if the method is explicitly "modelopt" if quant_method != "modelopt": return None # Look for ModelOpt-specific config structure if "quantization" in hf_quant_cfg: quant_config = hf_quant_cfg["quantization"] if isinstance(quant_config, dict): quant_algo = quant_config.get("quant_algo", "") if "FP8" in quant_algo: return "modelopt" else: # Check for compressed-tensors style config with specific quant_algo quant_algo = hf_quant_cfg.get("quant_algo", "") if isinstance(quant_algo, str) and "FP8" in quant_algo: return "modelopt" return None @classmethod def _from_config( cls, *, quant_method: str, kv_cache_quant_method: str | None, exclude_modules: list[str], original_config: dict[str, Any], **kwargs: Any, ) -> "ModelOptFp8Config": is_checkpoint_fp8_serialized = "FP8" in quant_method return cls(is_checkpoint_fp8_serialized, kv_cache_quant_method, exclude_modules) class ModelOptFp8LinearMethod(LinearMethodBase): """Linear method for Model Optimizer static quantization. Supports loading FP8 checkpoints with static weight scale and activation scale. Future support might be added for dynamic scales. Limitations: 1. Only support per-tensor quantization due to torch._scaled_mm support. 2. Only support float8_e4m3fn datatype Args: quant_config: The ModelOpt quantization config. 
""" def __init__(self, quant_config: ModelOptFp8Config) -> None: self.quant_config = quant_config self.fp8_linear = Fp8LinearOp( act_quant_static=True, act_quant_group_shape=GroupShape.PER_TENSOR ) def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): del input_size, output_size output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition weight_dtype = ( torch.float8_e4m3fn if self.quant_config.is_checkpoint_fp8_serialized else params_dtype ) weight = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition, dtype=weight_dtype ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight", weight) if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALE weight_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) scale[:] = torch.finfo(torch.float32).min layer.register_parameter("input_scale", scale) def process_weights_after_loading(self, layer: Module) -> None: weight = layer.weight max_w_scale = layer.weight_scale.max() if not (layer.weight_scale == layer.weight_scale[0]).all(): max_w_scale, weight = requantize_with_max_scale( layer.weight, layer.weight_scale, layer.logical_widths ) layer.weight = Parameter(weight.t(), requires_grad=False) layer.weight_scale = Parameter(max_w_scale, requires_grad=False) layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: return self.fp8_linear.apply( input=x, weight=layer.weight, weight_scale=layer.weight_scale, input_scale=layer.input_scale, bias=bias, ) class ModelOptFp8MoEMethod(FusedMoEMethodBase): """MoE method for ModelOpt FP8. Supports loading FP8 checkpoints with static weight scale and activation scale. Args: quant_config: The ModelOpt quantization config. 
""" def __init__( self, quant_config: ModelOptFp8Config, layer: FusedMoE, ) -> None: super().__init__(layer.moe_config) self.layer = layer self.quant_config = quant_config from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( cutlass_fp8_supported, ) self.cutlass_fp8_supported = cutlass_fp8_supported() self.flashinfer_moe_backend: FlashinferMoeBackend | None = None if envs.VLLM_USE_FLASHINFER_MOE_FP8 and has_flashinfer_moe(): self.flashinfer_moe_backend = get_flashinfer_moe_backend() if ( self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM and not self.moe.is_act_and_mul ): logger.info_once( "Non-gated MoE is not supported for min-latency mode," "falling back to high-throughput mode" ) self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS logger.info_once( f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" ) def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: # TRT LLM not supported with all2all yet. if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: return None elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( self.moe ) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize else: return super().maybe_make_prepare_finalize(routing_tables) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: assert self.moe_quant_config is not None experts = select_cutlass_fp8_gemm_impl( self.moe, self.moe_quant_config, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts def create_weights( self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): # Use FP8 dtype if checkpoint is serialized weight_dtype = ( torch.float8_e4m3fn if self.quant_config.is_checkpoint_fp8_serialized else params_dtype ) weight_loader = extra_weight_attrs.get("weight_loader") if self.moe.is_act_and_mul: w13_up_dim = 2 * intermediate_size_per_partition else: w13_up_dim = intermediate_size_per_partition w13_weight = ModelWeightParameter( data=torch.empty( num_experts, w13_up_dim, hidden_size, dtype=weight_dtype, ), input_dim=2, output_dim=1, weight_loader=weight_loader, ) layer.register_parameter("w13_weight", w13_weight) w2_weight = ModelWeightParameter( data=torch.empty( num_experts, hidden_size, intermediate_size_per_partition, dtype=weight_dtype, ), input_dim=2, output_dim=1, weight_loader=weight_loader, ) layer.register_parameter("w2_weight", w2_weight) if self.quant_config.is_checkpoint_fp8_serialized: # WEIGHT SCALES - Per-tensor scaling for ModelOpts # For gated MoE, allocate 2 scales for w1 and w3 respectively. # They will be combined to a single scale after weight loading. # For non-gated MoE, allocate 1 scale for w13. 
if self.moe.is_act_and_mul: w13_weight_scale_shape = (num_experts, 2) else: w13_weight_scale_shape = (num_experts, 1) w13_weight_scale = PerTensorScaleParameter( data=torch.full( w13_weight_scale_shape, 1.0, dtype=torch.float32, ), weight_loader=weight_loader, ) w2_weight_scale = PerTensorScaleParameter( data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Set weight loader attributes for scales extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) # INPUT SCALES - Per-tensor scaling for ModelOpt w13_input_scale = PerTensorScaleParameter( data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) w2_input_scale = PerTensorScaleParameter( data=torch.full((num_experts,), 1.0, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w13_input_scale", w13_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: """Process FP8 MoE weights after loading from serialized checkpoint. Only supports pre-quantized checkpoints with FP8 weights and scales. """ if self.flashinfer_moe_backend is not None: self._maybe_pad_intermediate_for_flashinfer(layer) layer.w13_weight = Parameter(layer.w13_weight.data, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False) from vllm._custom_ops import scaled_fp8_quant from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( per_tensor_dequantize, ) # Handle scale parameters if hasattr(layer, "w13_weight_scale") and layer.w13_weight_scale is not None: # Fp8 moe kernel needs single weight scale for w13 per expert. # We take the max of the w1 and w3 scales # then dequant and requant each expert. if ( layer.w13_weight_scale.dim() == 2 and layer.w13_weight_scale.shape[1] == 2 ): assert self.moe.is_act_and_mul, ( "w13_weight_scale should have 2 elements per expert " "only for gated MoE" ) # Get the maximum scale across w1 and w3 for each expert max_w13_scales = layer.w13_weight_scale.max(dim=1).values # Requantize each expert's weights using the combined scale # w13_weight (num_experts, 2 * intermediate_size, hidden_size) # where the first intermediate_size rows are w1, the next are w3 intermediate_size = layer.w13_weight.shape[1] // 2 for expert_id in range(layer.w13_weight.shape[0]): start = 0 for shard_id in range(2): # w1 and w3 # Dequantize using the original scale for this shard dq_weight = per_tensor_dequantize( layer.w13_weight[expert_id][ start : start + intermediate_size, : ], layer.w13_weight_scale[expert_id][shard_id], ) # Requantize using the combined max scale ( layer.w13_weight[expert_id][ start : start + intermediate_size, : ], _, ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id]) start += intermediate_size # Update the scale parameter to be per-expert layer.w13_weight_scale = Parameter(max_w13_scales, requires_grad=False) else: layer.w13_weight_scale = Parameter( layer.w13_weight_scale.data, requires_grad=False ) if hasattr(layer, "w2_weight_scale") and layer.w2_weight_scale is not None: layer.w2_weight_scale = Parameter( layer.w2_weight_scale.data, requires_grad=False ) # Input scales must be equal for each expert in fp8 MoE layers. 
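# The per-expert input scales loaded from the checkpoint are therefore collapsed
# below to a single scalar with .max(); taking the largest scale is the
# conservative choice for static activation quantization.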
if hasattr(layer, "w13_input_scale") and layer.w13_input_scale is not None: layer.w13_input_scale = Parameter( layer.w13_input_scale.max(), requires_grad=False ) if hasattr(layer, "w2_input_scale") and layer.w2_input_scale is not None: layer.w2_input_scale = Parameter( layer.w2_input_scale.max(), requires_grad=False ) if self.flashinfer_moe_backend is not None: if self.moe.is_act_and_mul: layer.w13_weight.data = swap_w13_to_w31(layer.w13_weight.data) if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: rotate_flashinfer_fp8_moe_weights(layer.w13_weight, layer.w2_weight) register_moe_scaling_factors(layer) def _maybe_pad_intermediate_for_flashinfer(self, layer: torch.nn.Module) -> None: """Pad intermediate size so FlashInfer kernels' alignment constraints hold. Some FlashInfer FP8 MoE kernels require the (gated) intermediate size used for GEMM to be divisible by a small alignment value. When this is not satisfied (e.g. with certain tensor-parallel sizes), we pad the gate/up and down projection weights along the intermediate dim. """ if not hasattr(layer, "w13_weight") or not hasattr(layer, "w2_weight"): return # Current local intermediate size (per partition) is the K dimension of # the down projection. num_experts, hidden_size, intermediate = layer.w2_weight.shape min_alignment = 16 padded_intermediate = round_up(intermediate, min_alignment) if padded_intermediate == intermediate: return logger.info( "Padding intermediate size from %d to %d for up/down projection weights.", intermediate, padded_intermediate, ) up_mult = 2 if self.moe.is_act_and_mul else 1 padded_gate_up_dim = up_mult * padded_intermediate # Pad w13 and w12 along its intermediate dimension. w13 = layer.w13_weight.data padded_w13 = w13.new_zeros((num_experts, padded_gate_up_dim, hidden_size)) padded_w13[:, : w13.shape[1], :] = w13 layer.w13_weight.data = padded_w13 w2 = layer.w2_weight.data padded_w2 = w2.new_zeros((num_experts, hidden_size, padded_intermediate)) padded_w2[:, :, :intermediate] = w2 layer.w2_weight.data = padded_w2 if hasattr(layer, "intermediate_size_per_partition"): layer.intermediate_size_per_partition = padded_intermediate def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: return None return fp8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, g1_alphas=layer.output1_scales_gate_scalar.squeeze(), w2_scale=layer.w2_weight_scale, g2_alphas=layer.output2_scales_scalar.squeeze(), a1_scale=layer.w13_input_scale, a1_gscale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, a2_gscale=layer.w2_input_scale_inv, per_act_token_quant=False, ) def apply( self, layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: if layer.enable_eplb: raise NotImplementedError( "EPLB not supported for `ModelOptFp8MoEMethod` yet." 
) assert layer.activation == "silu", ( f"Expected 'silu' activation but got {layer.activation}" ) assert not layer.renormalize return apply_flashinfer_per_tensor_scale_fp8( layer=layer, hidden_states=x, router_logits=router_logits, routing_bias=layer.e_score_correction_bias, global_num_experts=layer.global_num_experts, top_k=layer.top_k, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, apply_router_weight_on_input=layer.apply_router_weight_on_input, ) # Expert selection topk_weights, topk_ids, _ = layer.select_experts( hidden_states=x, router_logits=router_logits, ) if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: assert layer.activation in ("silu", "relu2_no_mul"), ( "Expected activation to be in ('silu', 'relu2_no_mul'), " f"but got {layer.activation}" ) return flashinfer_cutlass_moe_fp8( x, layer, topk_weights, topk_ids, inplace=False, activation=layer.activation, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, ) else: from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts assert self.moe_quant_config is not None return fused_experts( x, layer.w13_weight, layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, activation=layer.activation, quant_config=self.moe_quant_config, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, ) ModelOptFp8Config.LinearMethodCls = ModelOptFp8LinearMethod ModelOptFp8Config.FusedMoEMethodCls = ModelOptFp8MoEMethod ModelOptFp8Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod class ModelOptNvFp4Config(ModelOptQuantConfigBase): """Config class for ModelOpt FP4.""" def __init__( self, is_checkpoint_nvfp4_serialized: bool, kv_cache_quant_algo: str | None, exclude_modules: list[str], group_size: int = 16, ) -> None: super().__init__(exclude_modules) self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized if is_checkpoint_nvfp4_serialized: logger.warning( "Detected ModelOpt NVFP4 checkpoint. Please note that" " the format is experimental and could change in the future."
) self.group_size = group_size self.kv_cache_quant_algo = kv_cache_quant_algo def get_name(self) -> QuantizationMethods: return "modelopt_fp4" def get_supported_act_dtypes(self) -> list[torch.dtype]: return [torch.bfloat16, torch.half, torch.float8_e4m3fn] @classmethod def get_min_capability(cls) -> int: return 80 @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant ) -> QuantizationMethods | None: """Detect if this ModelOpt FP4 config should be used based on quantization config.""" if hf_quant_cfg is None: return None # Use the community standard 'quant_method' quant_method = hf_quant_cfg.get("quant_method", "").lower() # Only proceed if the method is explicitly "modelopt" if quant_method != "modelopt": return None # Look for ModelOpt-specific config structure if "quantization" in hf_quant_cfg: quant_config = hf_quant_cfg["quantization"] if isinstance(quant_config, dict): quant_algo = quant_config.get("quant_algo", "") if "NVFP4" in quant_algo: return "modelopt_fp4" else: # Check for compressed-tensors style config with specific # quant_algo field quant_algo = hf_quant_cfg.get("quant_algo", "") if isinstance(quant_algo, str) and "FP4" in quant_algo.upper(): return "modelopt_fp4" return None @classmethod def _from_config( cls, *, quant_method: str, kv_cache_quant_method: str | None, exclude_modules: list[str], original_config: dict[str, Any], group_size: int | None, **kwargs: Any, ) -> "ModelOptNvFp4Config": is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method if group_size is None: group_size = 16 # Default value # For FP4, these fields are required if is_checkpoint_nvfp4_serialized and "quantization" in original_config: # Check if required fields are present in the quantization config quant_config = original_config["quantization"] required_fields = ["group_size", "kv_cache_quant_algo", "exclude_modules"] missing_fields = [ field for field in required_fields if field not in quant_config ] if missing_fields: raise ValueError( f"NVFP4 quantization requires the following fields in " f"hf_quant_config.json: {missing_fields}" ) return cls( is_checkpoint_nvfp4_serialized, kv_cache_quant_method, exclude_modules, group_size, ) class ModelOptNvFp4LinearMethod(LinearMethodBase): """Linear method for Model Optimizer NVFP4. Supports loading NVFP4 checkpoints with the following structure: input_scale: torch.float32, scalar , weight: NVFP4(represented as byte) Shape: [1, X, y/2] weight_scale: FP8-E4M3, Shape: [X, Y], aka per block scale, weight_scale_2: torch.float32, scalar, Args: quant_config: The ModelOpt quantization config. """ def __init__(self, quant_config: ModelOptNvFp4Config) -> None: self.quant_config = quant_config self.marlin_input_dtype = None self.backend = "none" if envs.VLLM_NVFP4_GEMM_BACKEND is None: if has_flashinfer(): self.backend = "flashinfer-cutlass" elif cutlass_fp4_supported(): self.backend = "cutlass" elif is_fp4_marlin_supported(): self.backend = "marlin" elif envs.VLLM_NVFP4_GEMM_BACKEND.startswith("flashinfer-"): self.backend = envs.VLLM_NVFP4_GEMM_BACKEND assert has_flashinfer(), f"FlashInfer is required for {self.backend}" elif envs.VLLM_NVFP4_GEMM_BACKEND == "cutlass": self.backend = "cutlass" assert cutlass_fp4_supported(), f"Cutlass is required for {self.backend}" if self.backend == "none": raise ValueError( "No valid NVFP4 GEMM backend found. " "Please check your platform capability." 
) logger.info_once(f"Using {self.backend} for NVFP4 GEMM") def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: list[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): del input_size, output_size if not self.quant_config.is_checkpoint_nvfp4_serialized: raise ValueError( "NVFP4 quantization was selected, " " dynamic quantization is not supported." ) output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") layer.logical_widths = output_partition_sizes layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition if input_size_per_partition % 16 != 0: raise ValueError( "Unsupported model when in features size is not multiple of 16" ) # The nvfp4 weight is still represented as weight_dtype = ( torch.float8_e4m3fn if self.quant_config.is_checkpoint_nvfp4_serialized else params_dtype ) # Weight weight = ModelWeightParameter( data=torch.empty( # 2 fp4 items are packed in the input dimension layer.output_size_per_partition, layer.input_size_per_partition // 2, dtype=torch.uint8, ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight", weight) # Input Weight Scale input_scale = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("input_scale", input_scale) # Global Weight Scale weight_scale_2 = PerTensorScaleParameter( data=torch.empty(len(output_partition_sizes), dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("weight_scale_2", weight_scale_2) # Per Block Weight Scale weight_scale = ModelWeightParameter( data=torch.empty( output_size_per_partition, input_size_per_partition // self.quant_config.group_size, dtype=weight_dtype, ), input_dim=1, output_dim=0, weight_loader=weight_loader, ) layer.register_parameter("weight_scale", weight_scale) def process_weights_after_loading(self, layer: Module) -> None: # global scales: input_scale_2 = layer.input_scale.max().to(torch.float32) layer.input_scale = Parameter(input_scale_2, requires_grad=False) weight_scale_2 = layer.weight_scale_2.max().to(torch.float32) layer.weight_scale_2 = Parameter(weight_scale_2, requires_grad=False) layer.alpha = Parameter( layer.input_scale * layer.weight_scale_2, requires_grad=False ) # Calculate `1 / input_scale` so that we don't need to do so at runtime layer.input_scale_inv = Parameter( (1 / layer.input_scale).to(torch.float32), requires_grad=False ) # Swizzle the weight blockscale. # contracting dimension is input dimension # block_size = 16; assert layer.weight_scale.dtype == torch.float8_e4m3fn, ( "Weight Block scale must be represented as FP8-E4M3" ) if self.backend == "marlin": prepare_fp4_layer_for_marlin(layer) del layer.alpha del layer.input_scale elif self.backend == "flashinfer-trtllm": # FlashInfer TRTLLM FP4 GEMM requires a different weight layout. # FlashInfer provides nvfp4_quantize to quantize + shuffle the # layout but we use our own quantization so we have to call # shuffles ourselves. 
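# The calls below only reorder data: shuffle_matrix_a permutes the packed FP4
# weight and shuffle_matrix_sf_a permutes its block scale into the tile layout
# expected by the TRT-LLM epilogue (epilogue_tile_m rows per tile); nothing is
# requantized.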
from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a weight = layer.weight.data weight_scale = layer.weight_scale.data epilogue_tile_m = 128 weight = shuffle_matrix_a(weight.view(torch.uint8), epilogue_tile_m) weight_scale = ( shuffle_matrix_sf_a(weight_scale.view(torch.uint8), epilogue_tile_m) .reshape(weight_scale.shape) .view(torch.float8_e4m3fn) ) layer.weight_scale = Parameter(weight_scale, requires_grad=False) layer.weight = Parameter(weight, requires_grad=False) else: swizzled_weight_scale = swizzle_blockscale(layer.weight_scale) layer.weight_scale = Parameter(swizzled_weight_scale, requires_grad=False) layer.weight = Parameter(layer.weight.data, requires_grad=False) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: if self.backend == "marlin": return apply_fp4_marlin_linear( input=x, weight=layer.weight, weight_scale=layer.weight_scale, weight_scale_2=layer.weight_scale_2, workspace=layer.workspace, size_n=layer.output_size_per_partition, size_k=layer.input_size_per_partition, bias=bias, input_dtype=self.marlin_input_dtype, ) output_dtype = x.dtype output_shape = [x.shape[0], layer.weight.shape[0]] # quantize BF16 or FP16 to (FP4 and interleaved block scale) x_fp4, x_blockscale = scaled_fp4_quant(x, layer.input_scale_inv) # validate dtypes of quantized input, input block scale, # weight and weight_blockscale assert x_fp4.dtype == torch.uint8 assert layer.weight.dtype == torch.uint8 assert x_blockscale.dtype == torch.float8_e4m3fn assert layer.weight_scale.dtype == torch.float8_e4m3fn assert layer.alpha.dtype == torch.float32 mm_args = ( x_fp4, layer.weight, x_blockscale, layer.weight_scale, layer.alpha, output_dtype, ) if self.backend.startswith("flashinfer-"): backend_name = self.backend[len("flashinfer-") :] out = flashinfer_scaled_fp4_mm(*mm_args, backend=backend_name) else: assert self.backend == "cutlass" out = cutlass_scaled_fp4_mm(*mm_args) if bias is not None: out = out + bias return out.view(*output_shape) class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): """ MoE Method for FP4 Quantization. Args: quant_config: NVFP4 Quant Config """ def __init__( self, quant_config: ModelOptNvFp4Config, layer: FusedMoE, ) -> None: from vllm.model_executor.layers.quantization.utils.nvfp4_moe_support import ( detect_nvfp4_moe_support, # noqa: E501 ) super().__init__(layer.moe_config) self.quant_config = quant_config self.layer = layer _nvfp4 = detect_nvfp4_moe_support(self.__class__.__name__) self.cutlass_nvfp4_supported = _nvfp4.cutlass_supported self.allow_flashinfer = _nvfp4.allow_flashinfer self.use_marlin = _nvfp4.use_marlin self.marlin_input_dtype = None self.flashinfer_moe_backend = None if self.allow_flashinfer: self.flashinfer_moe_backend = get_flashinfer_moe_backend() logger.info_once( f"Using FlashInfer {self.flashinfer_moe_backend.value} kernels" " for ModelOptNvFp4FusedMoE." ) elif self.use_marlin: logger.info_once("Using Marlin for ModelOptNvFp4FusedMoE.") else: logger.info_once("Using Cutlass for ModelOptNvFp4FusedMoE.") def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> mk.FusedMoEPrepareAndFinalize | None: if self.use_marlin or ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM ): return None elif ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS ): # For now, fp4 moe only works with the flashinfer dispatcher. 
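# Roughly: the prepare/finalize object owns activation quantization and token
# dispatch before the grouped GEMMs (and the matching combine afterwards),
# while the GEMM implementation itself is chosen separately in
# select_gemm_impl() below.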
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( self.moe ) logger.debug_once("%s", prepare_finalize.__class__.__name__) return prepare_finalize else: return super().maybe_make_prepare_finalize(routing_tables) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEPermuteExpertsUnpermute: assert self.moe_quant_config is not None experts = select_nvfp4_gemm_impl( self.moe, self.moe_quant_config, allow_flashinfer=self.allow_flashinfer, ) logger.debug_once("Using %s", experts.__class__.__name__) return experts def uses_weight_scale_2_pattern(self) -> bool: """ FP4 variants use 'weight_scale_2' pattern for per-tensor weight scales. """ return True def create_weights( self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): if not self.quant_config.is_checkpoint_nvfp4_serialized: raise ValueError( "NVFP4 quantization was selected, " " dynamic quantization is not supported." ) layer.num_experts = num_experts layer.params_dtype = params_dtype layer.quant_config = self.quant_config weight_dtype = torch.uint8 weight_scale_dtype = torch.float8_e4m3fn weight_loader = extra_weight_attrs.get("weight_loader") global_num_experts = extra_weight_attrs.get("global_num_experts") # GEMM 1 w13_weight = ModelWeightParameter( data=torch.empty( num_experts, (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // 2, dtype=weight_dtype, ), input_dim=1, output_dim=2, weight_loader=weight_loader, ) layer.register_parameter("w13_weight", w13_weight) # GEMM 2 w2_weight = ModelWeightParameter( data=torch.empty( num_experts, hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // 2, dtype=weight_dtype, ), input_dim=1, output_dim=2, weight_loader=weight_loader, ) layer.register_parameter("w2_weight", w2_weight) w13_weight_scale = ModelWeightParameter( data=torch.empty( num_experts, (2 if self.moe.is_act_and_mul else 1) * intermediate_size_per_partition, # 2 fp4 items are packed in the input dimension hidden_size // self.quant_config.group_size, dtype=weight_scale_dtype, ), input_dim=1, output_dim=2, weight_loader=weight_loader, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) w2_weight_scale = ModelWeightParameter( data=torch.empty( num_experts, hidden_size, # 2 fp4 items are packed in the input dimension intermediate_size_per_partition // self.quant_config.group_size, dtype=weight_scale_dtype, ), input_dim=1, output_dim=2, weight_loader=weight_loader, ) layer.register_parameter("w2_weight_scale", w2_weight_scale) extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} ) w13_weight_scale_2 = PerTensorScaleParameter( data=torch.empty( num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32 ), weight_loader=weight_loader, ) layer.register_parameter("w13_weight_scale_2", w13_weight_scale_2) w2_weight_scale_2 = PerTensorScaleParameter( data=torch.empty(num_experts, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w2_weight_scale_2", w2_weight_scale_2) extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value} ) use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf( self.flashinfer_moe_backend ) global_scale_num_experts = global_num_experts if use_global_sf else num_experts w13_input_scale = 
PerTensorScaleParameter( data=torch.empty( global_scale_num_experts, 2 if self.moe.is_act_and_mul else 1, dtype=torch.float32, ), weight_loader=weight_loader, ) layer.register_parameter("w13_input_scale", w13_input_scale) w2_input_scale = PerTensorScaleParameter( data=torch.empty(global_scale_num_experts, dtype=torch.float32), weight_loader=weight_loader, ) layer.register_parameter("w2_input_scale", w2_input_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # GEMM 1 processing gemm1_weight = layer.w13_weight.data gemm1_weight_scale = layer.w13_weight_scale.data if ( self.allow_flashinfer and ( self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM ) and self.moe.is_act_and_mul ): gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1( gemm1_weight, gemm1_weight_scale, dim=-2 ) layer.w13_weight = Parameter(gemm1_weight, requires_grad=False) layer.w13_weight_scale = Parameter(gemm1_weight_scale, requires_grad=False) # Common processing for w13_weight_scale_2 if self.moe.is_act_and_mul and not torch.allclose( layer.w13_weight_scale_2[:, 0], layer.w13_weight_scale_2[:, 1] ): logger.warning_once( "w1_weight_scale_2 must match w3_weight_scale_2. " "Accuracy may be affected." ) w13_weight_scale_2 = layer.w13_weight_scale_2[:, 0].contiguous() layer.w13_weight_scale_2 = Parameter(w13_weight_scale_2, requires_grad=False) # Common processing for input scales and alphas use_global_sf = self.allow_flashinfer and is_flashinfer_supporting_global_sf( self.flashinfer_moe_backend ) if use_global_sf: # For backends provided by FlashInfer, the input global scales are # shared across all experts. w13_input_scale = ( layer.w13_input_scale.max().to(torch.float32).expand(layer.num_experts) ) else: w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32) layer.g1_alphas = Parameter( (w13_input_scale * w13_weight_scale_2).to(torch.float32), requires_grad=False, ) # This is for quantization, so we need to invert it. layer.w13_input_scale_quant = Parameter( (1 / w13_input_scale).to(torch.float32), requires_grad=False ) # GEMM 2 processing if use_global_sf: # For backends provided by FlashInfer, the input global scales are # shared across all experts. w2_input_scale = ( layer.w2_input_scale.max().to(torch.float32).expand(layer.num_experts) ) else: w2_input_scale = layer.w2_input_scale layer.g2_alphas = Parameter( (w2_input_scale * layer.w2_weight_scale_2).to(torch.float32), requires_grad=False, ) # This is for quantization, so we need to invert it.
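# As with GEMM1, the kernel consumes the reciprocal scale
# (w2_input_scale_quant == 1 / w2_input_scale), so quantizing the intermediate
# activations at runtime is a multiply rather than a divide.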
layer.w2_input_scale_quant = Parameter( (1 / w2_input_scale).to(torch.float32), requires_grad=False ) # TensorRT-LLM specific processing if ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM ): # Prepare static weights for TRT-LLM kernel # alternate: prepare_static_weight_layouts_for_trtllm_moe ( gemm1_weights_fp4_shuffled, gemm1_scales_fp4_shuffled, gemm2_weights_fp4_shuffled, gemm2_scales_fp4_shuffled, ) = prepare_static_weights_for_trtllm_fp4_moe( layer.w13_weight, layer.w2_weight, layer.w13_weight_scale, layer.w2_weight_scale, layer.w2_weight.size(-2), # hidden_size layer.w13_weight.size(-2) // 2, # intermediate_size layer.w13_weight.size(0), # num_experts ) logger.debug_once("Finished shuffling weights for TRT-LLM MOE") layer.gemm1_weights_fp4_shuffled = Parameter( gemm1_weights_fp4_shuffled, requires_grad=False ) layer.gemm2_weights_fp4_shuffled = Parameter( gemm2_weights_fp4_shuffled, requires_grad=False ) layer.gemm1_scales_fp4_shuffled = Parameter( gemm1_scales_fp4_shuffled, requires_grad=False ) layer.gemm2_scales_fp4_shuffled = Parameter( gemm2_scales_fp4_shuffled, requires_grad=False ) # Additional parameter needed for TRT-LLM layer.g1_scale_c = Parameter( (layer.w2_input_scale_quant * layer.g1_alphas).to(torch.float32), requires_grad=False, ) # Clean up weights that won't be used by TRT-LLM del layer.w2_weight del layer.w2_weight_scale del layer.w13_weight del layer.w13_weight_scale elif self.use_marlin: # Marlin processing prepare_moe_fp4_layer_for_marlin(layer) del layer.g1_alphas del layer.g2_alphas del layer.w13_input_scale_quant del layer.w2_input_scale_quant else: # Non-TRT-LLM processing (Cutlass or non-flashinfer) w13_blockscale_swizzled = swizzle_blockscale(layer.w13_weight_scale) layer.w13_weight_scale = Parameter( w13_blockscale_swizzled, requires_grad=False ) w13_weight = layer.w13_weight intermediate_size_pad = w13_blockscale_swizzled.size(1) - w13_weight.size(1) if intermediate_size_pad: # padding gated activations will require to split w1 and w3 # and pad them individually assert not self.moe.is_act_and_mul, ( "The intermediate size required padding, " "but padding is not implemented for gated activations" ) layer.w13_weight = Parameter( torch.nn.functional.pad( w13_weight, (0, 0, 0, intermediate_size_pad) ), requires_grad=False, ) layer.w2_weight = Parameter( torch.nn.functional.pad( layer.w2_weight, (0, intermediate_size_pad // 2, 0, 0) ), requires_grad=False, ) layer.w2_weight_scale = Parameter( torch.nn.functional.pad( layer.w2_weight_scale, (0, intermediate_size_pad // 16) ), requires_grad=False, ) w2_blockscale_swizzled = swizzle_blockscale(layer.w2_weight_scale) layer.w2_weight_scale = Parameter( w2_blockscale_swizzled, requires_grad=False ) def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: if ( self.use_marlin or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM ): return None return nvfp4_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, g1_alphas=layer.g1_alphas, g2_alphas=layer.g2_alphas, a1_gscale=layer.w13_input_scale_quant, a2_gscale=layer.w2_input_scale_quant, ) @property def supports_eplb(self) -> bool: return True def apply( self, layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: if not self.moe.is_act_and_mul: assert ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS ), ( "Non-gated activations are only 
supported by the" " flashinfer CUTLASS backend for modelopt checkpoints" ) if ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM and not layer.enable_eplb ): return flashinfer_trtllm_fp4_moe( layer=layer, x=x, router_logits=router_logits, top_k=layer.top_k, global_num_experts=layer.global_num_experts, num_expert_group=layer.num_expert_group, topk_group=layer.topk_group, custom_routing_function=layer.custom_routing_function, e_score_correction_bias=layer.e_score_correction_bias, ) topk_weights, topk_ids, _ = layer.select_experts( hidden_states=x, router_logits=router_logits, ) # EPLB path if ( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM ): return flashinfer_trtllm_fp4_routed_moe( layer=layer, x=x, topk_ids=topk_ids, topk_weights=topk_weights, top_k=layer.top_k, global_num_experts=layer.global_num_experts, ) if self.use_marlin: return fused_marlin_moe( x, layer.w13_weight, layer.w2_weight, None, None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, topk_weights, topk_ids, global_scale1=layer.w13_weight_scale_2, global_scale2=layer.w2_weight_scale_2, quant_type_id=scalar_types.float4_e2m1f.id, apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, input_dtype=self.marlin_input_dtype, ) elif self.allow_flashinfer: assert self.flashinfer_moe_backend in ( FlashinferMoeBackend.CUTLASS, FlashinferMoeBackend.CUTEDSL, ) if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 flashinfer_cutlass_moe_fp4, ) flashinfer_fn_moe_fp4 = flashinfer_cutlass_moe_fp4 else: from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( # noqa: E501 flashinfer_cutedsl_moe_fp4, ) flashinfer_fn_moe_fp4 = flashinfer_cutedsl_moe_fp4 assert self.moe_quant_config is not None return flashinfer_fn_moe_fp4( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, quant_config=self.moe_quant_config, inplace=False, activation=layer.activation, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, ) else: # If no modular kernel is provided, use cutlass_moe_fp4 for TP case # only (no EP). from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 assert self.moe_quant_config is not None return cutlass_moe_fp4( a=x, w1_fp4=layer.w13_weight, w2_fp4=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, quant_config=self.moe_quant_config, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, # TODO: derive from arguments m=x.shape[0], n=layer.w2_weight.shape[2] * 2, k=x.shape[1], e=layer.w13_weight.shape[0], ) ModelOptNvFp4Config.LinearMethodCls = ModelOptNvFp4LinearMethod ModelOptNvFp4Config.FusedMoEMethodCls = ModelOptNvFp4FusedMoE ModelOptNvFp4Config.KVCacheMethodCls = ModelOptFp8KVCacheMethod
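# Illustrative reference only (not executed): a minimal legacy-style
# hf_quant_config.json that ModelOptQuantConfigBase.from_config() above would
# accept and route to ModelOptNvFp4Config. The module names in exclude_modules
# are hypothetical placeholders.
#
# >>> ModelOptNvFp4Config.from_config(
# ...     {
# ...         "quantization": {
# ...             "quant_algo": "NVFP4",
# ...             "kv_cache_quant_algo": "FP8",
# ...             "group_size": 16,
# ...             "exclude_modules": ["lm_head", "model.layers.0.mlp*"],
# ...         }
# ...     }
# ... )
#
# The compressed-tensors style layout (top-level "quant_algo",
# "kv_cache_quant_algo", "group_size" and an "ignore" list) is parsed by the
# same method.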