# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Any, Callable, Optional, Union import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, fp8_w8a8_moe_quant_config, mxfp4_w4a4_moe_quant_config) from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( is_rocm_aiter_moe_enabled) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( prepare_moe_fp8_layer_for_marlin) from vllm.model_executor.layers.quantization.utils.mxfp4_utils import ( OCP_MX_BLOCK_SIZE) from vllm.model_executor.layers.quantization.utils.quant_utils import ( GroupShape) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.scalar_type import scalar_types logger = init_logger(__name__) __all__ = [ "QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkW4A4MXFp4MoEMethod" ] class QuarkMoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) @staticmethod def get_moe_method( quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 module: torch.nn.Module, layer_name: str) -> "QuarkMoEMethod": layer_quant_config = quant_config._find_matched_config( layer_name, module) if (layer_quant_config.get("output_tensors") or layer_quant_config.get("bias")): raise NotImplementedError("Currently, Quark models with " "output_tensors and bias " "quantized are not supported") weight_config = layer_quant_config.get("weight") input_config = layer_quant_config.get("input_tensors") if quant_config._is_fp8_w8a8(weight_config, input_config): return QuarkW8A8Fp8MoEMethod(weight_config, input_config, module.moe_config) elif quant_config._is_mx_fp4(weight_config, input_config): return QuarkW4A4MXFp4MoEMethod(weight_config, input_config, module.moe_config) else: raise RuntimeError("Unsupported FusedMoe scheme") class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): def __init__( self, weight_config: dict[str, Any], input_config: dict[str, Any], moe: FusedMoEConfig, ): super().__init__(moe) self.weight_quant = weight_config self.input_quant = input_config self.weight_qscheme = self.weight_quant.get("qscheme") self.input_qscheme = self.input_quant.get("qscheme") per_tensor = (self.weight_qscheme == "per_tensor" and self.input_qscheme == "per_tensor") per_channel = (self.weight_qscheme == "per_channel" and self.input_qscheme == "per_channel") self.act_quant_group_shape = GroupShape.PER_TOKEN \ if per_channel else GroupShape.PER_TENSOR if not (per_tensor or per_channel): raise ValueError( "For FP8 Fused MoE layers, only per-tensor and per-channel " "scales for weights and activations are supported. Found " f"{self.weight_qscheme}, {self.input_qscheme}") # noqa E501 self.static_input_scales = not self.input_quant.get("is_dynamic") if self.static_input_scales and per_channel: raise ValueError( "For FP8 Fused MoE layer, we require either per tensor or " "channelwise, dynamic per token quantization.") # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization self.use_marlin = (not current_platform.has_device_capability(89) or envs.VLLM_TEST_FORCE_FP8_MARLIN) # Disable marlin for rocm if current_platform.is_rocm(): self.use_marlin = False self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): layer.intermediate_size_per_partition = intermediate_size_per_partition layer.hidden_size = hidden_size layer.num_experts = num_experts layer.orig_dtype = params_dtype layer.weight_block_size = None params_dtype = torch.float8_e4m3fn # WEIGHTS w13_weight = torch.nn.Parameter(torch.empty( num_experts, 2 * intermediate_size_per_partition, hidden_size, dtype=params_dtype), requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter(torch.empty( num_experts, hidden_size, intermediate_size_per_partition, dtype=params_dtype), requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES if self.weight_qscheme == "per_tensor": # Allocate 2 scales for w1 and w3 respectively. # They are combined to a single scale after weight loading. w13_weight_scale = torch.nn.Parameter(torch.ones( num_experts, 2, dtype=torch.float32), requires_grad=False) layer.register_parameter("w13_weight_scale", w13_weight_scale) w2_weight_scale = torch.nn.Parameter(torch.ones( num_experts, dtype=torch.float32), requires_grad=False) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add PER-TENSOR quantization for FusedMoE.weight_loader. extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) elif self.weight_qscheme == "per_channel": # quark's scale is 1 dim. w13_weight_scale = torch.nn.Parameter(torch.ones( num_experts, 2 * intermediate_size_per_partition, dtype=torch.float32), requires_grad=False) layer.register_parameter("w13_weight_scale", w13_weight_scale) w2_weight_scale = torch.nn.Parameter(torch.ones( num_experts, hidden_size, dtype=torch.float32), requires_grad=False) layer.register_parameter("w2_weight_scale", w2_weight_scale) # Add PER-CHANNEL quantization for FusedMoE.weight_loader. extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.static_input_scales: w13_input_scale = torch.nn.Parameter(torch.ones( num_experts, dtype=torch.float32), requires_grad=False) layer.register_parameter("w13_input_scale", w13_input_scale) set_weight_attrs(w13_input_scale, extra_weight_attrs) w2_input_scale = torch.nn.Parameter(torch.ones( num_experts, dtype=torch.float32), requires_grad=False) layer.register_parameter("w2_input_scale", w2_input_scale) set_weight_attrs(w2_input_scale, extra_weight_attrs) else: layer.w13_input_scale = None layer.w2_input_scale = None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Fp8 moe kernels require a single activation scale. # We take the max of all the scales in case they differ. if self.static_input_scales: if (layer.w13_input_scale is None or layer.w2_input_scale is None): raise ValueError( "QuantConfig has static quantization, but found " "activation scales are None.") if (not all_close_1d(layer.w13_input_scale) or not all_close_1d(layer.w2_input_scale)): logger.warning_once( "Found input_scales that are not equal for " "fp8 MoE layer. Using the maximum across experts " "for each layer. ") layer.w13_input_scale = torch.nn.Parameter( layer.w13_input_scale.max(), requires_grad=False) layer.w2_input_scale = torch.nn.Parameter( layer.w2_input_scale.max(), requires_grad=False) if current_platform.is_fp8_fnuz(): # Normalize the weights and scales w13_weight, w13_weight_scale, w13_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale) w2_weight, w2_weight_scale, w2_input_scale = \ normalize_e4m3fn_to_e4m3fnuz( layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale) # Reset the parameter layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, requires_grad=False) if w13_input_scale is not None: layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False) if w2_input_scale is not None: layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, requires_grad=False) # For per-tensor case, Fp8 moe kernel needs single weight scale # for w13 per expert. Use max then dequant and requant each expert. if self.weight_qscheme == "per_tensor": assert layer.w13_weight_scale is not None shard_size = layer.intermediate_size_per_partition max_w13_scales = layer.w13_weight_scale.max(dim=1).values for expert_id in range(layer.local_num_experts): start = 0 for shard_id in range(2): dq_weight = per_tensor_dequantize( layer.w13_weight[expert_id][start:start + shard_size, :], layer.w13_weight_scale[expert_id][shard_id]) layer.w13_weight[expert_id][ start:start + shard_size, :], _ = ops.scaled_fp8_quant( dq_weight, max_w13_scales[expert_id]) start += shard_size layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False) # quark's scale is 1 dim. elif self.weight_qscheme == "per_channel": if self.act_quant_group_shape == GroupShape.PER_TOKEN: w13_weight_scale = layer.w13_weight_scale.unsqueeze(-1) layer.w13_weight_scale = torch.nn.Parameter( w13_weight_scale, requires_grad=False) w2_weight_scale = layer.w2_weight_scale.unsqueeze(-1) layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False) # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 rocm_aiter_fused_experts, shuffle_weights) # reshaping weights is required for aiter moe kernel. shuffled_w13, shuffled_w2 = shuffle_weights( layer.w13_weight.data, layer.w2_weight.data) layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts elif self.use_marlin: prepare_moe_fp8_layer_for_marlin(layer, False) # Activations not quantized for marlin. del layer.w13_input_scale del layer.w2_input_scale self.fused_experts_func = None else: from vllm.model_executor.layers.fused_moe import fused_experts self.fused_experts_func = fused_experts def get_fused_moe_quant_config( self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: return fp8_w8a8_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=layer.w13_input_scale, a2_scale=layer.w2_input_scale, per_act_token_quant=self.weight_qscheme == "per_channel", ) def apply( self, layer: torch.nn.Module, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, renormalize: bool, use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: raise NotImplementedError( "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.") topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, top_k=top_k, renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) if self.rocm_aiter_moe_enabled: return self.rocm_aiter_fused_experts_func( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, quant_config=self.moe_quant_config, expert_map=expert_map) if self.use_marlin: assert activation == "silu", ( f"{activation} not supported for Marlin MoE.") return torch.ops.vllm.fused_marlin_moe( x, layer.w13_weight, layer.w2_weight, None, None, layer.w13_weight_scale, layer.w2_weight_scale, router_logits, topk_weights, topk_ids, quant_type_id=scalar_types.float8_e4m3fn.id, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map) assert self.fused_experts_func is not None return self.fused_experts_func( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, global_num_experts=global_num_experts, expert_map=expert_map, quant_config=self.moe_quant_config) class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod): def __init__( self, weight_config: dict[str, Any], input_config: dict[str, Any], moe: FusedMoEConfig, ): super().__init__(moe) self.weight_quant = weight_config self.input_quant = input_config weight_qscheme = self.weight_quant.get("qscheme") input_qscheme = self.input_quant.get("qscheme") if not (weight_qscheme == "per_group" and input_qscheme == "per_group"): raise ValueError( "For MX(FP4) Fused MoE layers, only per-group scales " "for weights and activations are supported. Found " f"{weight_qscheme}, {input_qscheme}") # noqa E501 self.static_input_scales = not self.input_quant.get("is_dynamic") if self.static_input_scales: raise NotImplementedError( "QuarkW4A4MXFp4MoEMethod with static input scales is currently " "not implemented. Please open an issue.") if not current_platform.supports_mx(): self.emulate = True logger.warning_once( "The current platform does not support native MXFP4 " "computation. Simulated weight dequantization and activation " "QDQ (quantize and dequantize) will be used, with the linear " "layers computed in high precision.") else: self.emulate = True logger.warning_once( "The current platform supports native MXFP4 " "computation, but kernels are not yet integrated in vLLM. " "Simulated weight dequantization and activation " "QDQ (quantize and dequantize) will be used, with the linear " "layers computed in high precision.") def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): # Add the quantization method used (per tensor/grouped/channel) # to ensure the weight scales are loaded in properly extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}) params_dtype = torch.uint8 # WEIGHTS w13_weight = torch.nn.Parameter(torch.empty( num_experts, 2 * intermediate_size_per_partition, hidden_size // 2, dtype=params_dtype), requires_grad=False) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter(torch.empty( num_experts, hidden_size, intermediate_size_per_partition // 2, dtype=params_dtype), requires_grad=False) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # WEIGHT_SCALES w13_weight_scale = torch.nn.Parameter( torch.ones( num_experts, 2 * intermediate_size_per_partition, hidden_size // OCP_MX_BLOCK_SIZE, dtype=params_dtype, ), requires_grad=False, ) w2_weight_scale = torch.nn.Parameter( torch.ones( num_experts, hidden_size, intermediate_size_per_partition // OCP_MX_BLOCK_SIZE, dtype=params_dtype, ), requires_grad=False, ) set_weight_attrs(w2_weight_scale, extra_weight_attrs) set_weight_attrs(w13_weight_scale, extra_weight_attrs) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) def get_fused_moe_quant_config( self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]: return mxfp4_w4a4_moe_quant_config( w1_scale=layer.w13_weight_scale, w2_scale=layer.w2_weight_scale, a1_scale=None, a2_scale=None, block_shape=None, ) def apply( self, layer: torch.nn.Module, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, renormalize: bool, use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: assert self.fused_experts is None if enable_eplb: raise NotImplementedError( "EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet.") from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, top_k=top_k, renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, routed_scaling_factor=routed_scaling_factor, e_score_correction_bias=e_score_correction_bias, indices_type=self.topk_indices_dtype) out = fused_experts( x, layer.w13_weight, layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, activation=activation, global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, quant_config=self.moe_quant_config, ) return out