# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Fused-MoE quantization methods for Quark-quantized checkpoints."""

from typing import Any

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    fp8_w8a8_moe_quant_config,
    ocp_mx_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    prepare_moe_fp8_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
    OCP_MX_BLOCK_SIZE,
    OCP_MX_Scheme,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d,
    normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)

__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkOCP_MX_MoEMethod"]


class QuarkMoEMethod(FusedMoEMethodBase):
    """Common base for Quark fused-MoE methods and factory for the concrete ones."""

    def __init__(self, moe: FusedMoEConfig):
        super().__init__(moe)

    @staticmethod
    def get_moe_method(
        quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
        module: torch.nn.Module,
        layer_name: str,
    ) -> "QuarkMoEMethod":
        """Select the MoE method matching the layer's Quark quantization config.

        Args:
            quant_config: Quark config used to look up the per-layer settings.
            module: The MoE module; its ``moe_config`` is forwarded to the
                selected method.
            layer_name: Name used to match the layer's quantization config.

        Returns:
            A ``QuarkW8A8Fp8MoEMethod`` or ``QuarkOCP_MX_MoEMethod`` instance.

        Raises:
            NotImplementedError: If the matched config has ``output_tensors``
                or ``bias`` entries.
            RuntimeError: If the weight/input config matches neither scheme.
        """
        layer_quant_config = quant_config._find_matched_config(layer_name, module)

        if layer_quant_config.get("output_tensors") or layer_quant_config.get("bias"):
            raise NotImplementedError(
                "Currently, Quark models with "
                "output_tensors and bias "
                "quantized are not supported"
            )
        weight_config = layer_quant_config.get("weight")
        input_config = layer_quant_config.get("input_tensors")

        if quant_config._is_fp8_w8a8(weight_config, input_config):
            return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config)
        elif quant_config._is_ocp_mx(weight_config, input_config):
            return QuarkOCP_MX_MoEMethod(weight_config, input_config, module.moe_config)
        else:
            raise RuntimeError("Unsupported FusedMoe scheme")


class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
    """Fused-MoE method for Quark FP8 weights with FP8 activations (W8A8).

    Supports per-tensor weights/activations (static or dynamic input scales)
    and per-channel weights with dynamic per-channel activations.
    """

    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
    ):
        """Validate the quantization schemes and pick the kernel backend.

        Args:
            weight_config: Quark weight config; ``qscheme`` is read from it.
            input_config: Quark activation config; ``qscheme`` and
                ``is_dynamic`` are read from it.
            moe: Fused-MoE configuration forwarded to the base class.

        Raises:
            ValueError: If the schemes are not both per-tensor or both
                per-channel, or if per-channel is combined with static
                input scales.
        """
        super().__init__(moe)
        self.weight_quant = weight_config
        self.input_quant = input_config

        self.weight_qscheme = self.weight_quant.get("qscheme")
        self.input_qscheme = self.input_quant.get("qscheme")
        per_tensor = (
            self.weight_qscheme == "per_tensor" and self.input_qscheme == "per_tensor"
        )
        per_channel = (
            self.weight_qscheme == "per_channel" and self.input_qscheme == "per_channel"
        )
        self.act_quant_group_shape = (
            GroupShape.PER_TOKEN if per_channel else GroupShape.PER_TENSOR
        )
        if not (per_tensor or per_channel):
            raise ValueError(
                "For FP8 Fused MoE layers, only per-tensor and per-channel "
                "scales for weights and activations are supported. Found "
                f"{self.weight_qscheme}, {self.input_qscheme}"
            )  # noqa E501

        self.static_input_scales = not self.input_quant.get("is_dynamic")
        if self.static_input_scales and per_channel:
            raise ValueError(
                "For FP8 Fused MoE layer, we require either per tensor or "
                "channelwise, dynamic per token quantization."
            )

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization
        self.use_marlin = (
            not current_platform.has_device_capability(89)
            or envs.VLLM_TEST_FORCE_FP8_MARLIN
        )
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        self.rocm_aiter_moe_enabled = rocm_aiter_ops.is_fused_moe_enabled()

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register FP8 expert weights, weight scales and input scales on ``layer``.

        Args:
            layer: Module that receives the parameters and size attributes.
            num_experts: Number of experts (leading dim of every parameter).
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Intermediate size of this shard.
            params_dtype: Original dtype; stored as ``layer.orig_dtype``. The
                weights themselves are always allocated as
                ``torch.float8_e4m3fn``.
            **extra_weight_attrs: Attributes attached to every parameter;
                ``quant_method`` is added for the scale parameters.
        """
        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None
        params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.weight_qscheme == "per_tensor":
            # Allocate 2 scales for w1 and w3 respectively.
            # They are combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
            )
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        elif self.weight_qscheme == "per_channel":
            # quark's scale is 1 dim.
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * intermediate_size_per_partition,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(num_experts, hidden_size, dtype=torch.float32),
                requires_grad=False,
            )
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
            )
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.static_input_scales:
            w13_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(
                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
            )
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Post-process loaded weights and scales into the layout the kernels use.

        Steps, in order: collapse static input scales to their maximum,
        normalize to e4m3fnuz when the platform requires it, merge the two
        per-tensor w13 scales per expert (requantizing the weights) or
        unsqueeze per-channel scales, then apply the AITER shuffle or the
        Marlin preparation depending on the selected backend.

        Raises:
            ValueError: If static input scales are configured but missing.
        """
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.static_input_scales:
            if layer.w13_input_scale is None or layer.w2_input_scale is None:
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None."
                )
            if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
                layer.w2_input_scale
            ):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer. "
                )
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False
            )
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False
            )

        if current_platform.is_fp8_fnuz():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = (
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
                )
            )
            w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
                layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
            )
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(
                w13_weight_scale, requires_grad=False
            )
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False
                )
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(
                w2_weight_scale, requires_grad=False
            )
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False
                )

        # For per-tensor case, Fp8 moe kernel needs single weight scale
        # for w13 per expert. Use max then dequant and requant each expert.
        if self.weight_qscheme == "per_tensor":
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values

            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
                        ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    )
                    start += shard_size

            layer.w13_weight_scale = torch.nn.Parameter(
                max_w13_scales, requires_grad=False
            )
        # quark's scale is 1 dim.
        elif self.weight_qscheme == "per_channel":
            if self.act_quant_group_shape == GroupShape.PER_TOKEN:
                w13_weight_scale = layer.w13_weight_scale.unsqueeze(-1)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False
                )
                w2_weight_scale = layer.w2_weight_scale.unsqueeze(-1)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False
                )

        # Property to determine if AITER is used
        if self.rocm_aiter_moe_enabled:
            # reshaping weights is required for aiter moe kernel.
            shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
                layer.w13_weight.data, layer.w2_weight.data
            )

            layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)

        elif self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the FP8 W8A8 quant config from the layer's scales.

        The Marlin branch of ``process_weights_after_loading`` deletes
        ``w13_input_scale`` / ``w2_input_scale`` from the layer, so they are
        read with a ``None`` fallback here rather than raising
        ``AttributeError``.
        """
        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=getattr(layer, "w13_input_scale", None),
            a2_scale=getattr(layer, "w2_input_scale", None),
            per_act_token_quant=self.input_qscheme == "per_channel",
            per_out_ch_quant=self.weight_qscheme == "per_channel",
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens and run the fused experts on the selected backend.

        Backend priority is ROCm AITER, then Marlin, then the default
        ``fused_experts`` path (which runs with ``inplace=True``).

        Args:
            layer: The fused-MoE layer holding weights and routing settings.
            x: Hidden states passed to expert selection and the experts.
            router_logits: Router logits used for expert selection.

        Returns:
            The output of the selected fused-experts kernel.
        """
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
                rocm_aiter_fused_experts,
            )

            return rocm_aiter_fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=layer.activation,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                quant_config=self.moe_quant_config,
                expert_map=layer.expert_map,
            )
        elif self.use_marlin:
            assert layer.activation == "silu", (
                f"{layer.activation} not supported for Marlin MoE."
            )
            return fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=layer.activation,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                global_num_experts=layer.global_num_experts,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
            )


class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
    """Fused-MoE method for Quark OCP MX (per-group) quantized weights.

    Runs natively through the ROCm AITER kernel when available for the
    ``w_mxfp4_a_mxfp4`` scheme; otherwise falls back to emulation.
    """

    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
    ):
        """Validate the MX scheme and decide between native and emulated mode.

        Args:
            weight_config: Quark weight config; ``qscheme`` and ``dtype`` are
                read from it.
            input_config: Quark activation config; ``qscheme``, ``dtype`` and
                ``is_dynamic`` are read from it.
            moe: Fused-MoE configuration forwarded to the base class.

        Raises:
            ValueError: If weights and activations are not both per-group.
            NotImplementedError: If static input scales are requested.
        """
        super().__init__(moe)
        self.weight_quant = weight_config
        self.input_quant = input_config

        weight_qscheme = self.weight_quant.get("qscheme")
        input_qscheme = self.input_quant.get("qscheme")
        if not (weight_qscheme == "per_group" and input_qscheme == "per_group"):
            raise ValueError(
                "For MX(FP4) Fused MoE layers, only per-group scales "
                "for weights and activations are supported. Found "
                f"{weight_qscheme}, {input_qscheme}"
            )  # noqa E501

        self.static_input_scales = not self.input_quant.get("is_dynamic")

        # Quark dtype names ("fp4", "fp6_...") are mapped to their MX names.
        self.weight_dtype = self.weight_quant["dtype"].replace("fp", "mxfp")
        self.input_dtype = self.input_quant["dtype"].replace("fp", "mxfp")
        # None when this torch build does not expose the packed FP4 dtype.
        self.fp4_dtype = getattr(torch, "float4_e2m1fn_x2", None)

        self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
            self.input_dtype, self.weight_dtype
        )

        if self.static_input_scales:
            raise NotImplementedError(
                "QuarkOCP_MX_MoEMethod with static input scales is currently "
                "not implemented. Please open an issue."
            )

        self.use_rocm_aiter_moe = rocm_aiter_ops.is_fused_moe_enabled()

        # Native execution needs platform MX support, the AITER MoE kernel
        # and the w_mxfp4_a_mxfp4 scheme; anything else is emulated.
        self.emulate = not current_platform.supports_mx() or not (
            self.use_rocm_aiter_moe and self.ocp_mx_scheme == "w_mxfp4_a_mxfp4"
        )
        if self.emulate:
            logger.warning_once(
                f"The current mode (supports_mx={current_platform.supports_mx()}, "
                f"use_mxfp4_aiter_moe={self.use_rocm_aiter_moe}, "
                f"ocp_mx_scheme={self.ocp_mx_scheme}) "
                "does not support native MXFP4/MXFP6 "
                "computation. Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision."
            )
        else:
            logger.warning_once(
                "The current mode supports native MoE MXFP4 computation"
            )

    def get_packed_dim(self, dim: int, quant_dtype: str):
        """Return the number of bytes needed to store ``dim`` packed elements.

        Args:
            dim: Unpacked number of elements along the packed dimension.
            quant_dtype: ``"mxfp4"`` for two elements per byte; any other
                value is treated as FP6 (four elements per three bytes).
        """
        if quant_dtype == "mxfp4":
            assert dim % 2 == 0
            return dim // 2
        else:
            # FP6 packs 4 * 6 = 24 bits on 3 bytes.
            assert (dim * 3) % 4 == 0
            return (dim * 3) // 4

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register packed ``uint8`` expert weights and block scales on ``layer``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Number of experts (leading dim of every parameter).
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Intermediate size of this shard.
            params_dtype: Ignored; all parameters are allocated as
                ``torch.uint8``.
            **extra_weight_attrs: Attributes attached to every parameter;
                ``quant_method`` is set to the block scheme.
        """
        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
        )

        params_dtype = torch.uint8

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                self.get_packed_dim(hidden_size, self.weight_dtype),
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)

        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                self.get_packed_dim(intermediate_size_per_partition, self.weight_dtype),
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)

        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // OCP_MX_BLOCK_SIZE,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

    def process_weights_after_loading(self, layer):
        """Prepare scales and weights for the native AITER kernel.

        Does nothing in emulation mode. Otherwise the weight scales are
        pre-shuffled with ``e8m0_shuffle`` and, when torch exposes the packed
        FP4 dtype, the weights are re-viewed as that dtype.
        """
        if self.emulate:
            return

        from aiter.utility.fp4_utils import e8m0_shuffle

        # Pre-shuffle weight scales
        s0, s1, _ = layer.w13_weight_scale.shape
        w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
        w13_weight_scale = e8m0_shuffle(w13_weight_scale)
        layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)

        s0, s1, _ = layer.w2_weight_scale.shape
        w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
        w2_weight_scale = e8m0_shuffle(w2_weight_scale)
        layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)

        if self.fp4_dtype is not None:
            layer.w13_weight = torch.nn.Parameter(
                layer.w13_weight.view(self.fp4_dtype),
                requires_grad=layer.w13_weight.requires_grad,
            )
            layer.w2_weight = torch.nn.Parameter(
                layer.w2_weight.view(self.fp4_dtype),
                requires_grad=layer.w2_weight.requires_grad,
            )

        torch.cuda.empty_cache()

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the OCP MX quant config from the layer's weight scales.

        Activation scales are always ``None`` (static input scales are
        rejected in ``__init__``).
        """
        return ocp_mx_moe_quant_config(
            quant_dtype=self.input_dtype,
            weight_dtype=self.weight_dtype,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=None,
            a2_scale=None,
            block_shape=None,
        )

    @property
    def allow_inplace(self) -> bool:
        """Always ``True`` for this method."""
        return True

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route tokens and run the experts natively (AITER) or emulated.

        Args:
            layer: The fused-MoE layer holding weights and routing settings.
            x: Hidden states passed to expert selection and the experts.
            router_logits: Router logits used for expert selection.

        Returns:
            The output of the selected fused-experts kernel.
        """
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        if not self.emulate:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
                rocm_aiter_fused_experts,
            )

            out = rocm_aiter_fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=layer.activation,
                quant_config=self.moe_quant_config,
                expert_map=layer.expert_map,
            )
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts

            out = fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=layer.activation,
                global_num_experts=layer.global_num_experts,
                apply_router_weight_on_input=layer.apply_router_weight_on_input,
                expert_map=layer.expert_map,
                quant_config=self.moe_quant_config,
            )
        return out