# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Callable, Optional

import torch

import vllm.model_executor.layers.fused_moe  # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE,
                                                  FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform

logger = init_logger(__name__)

__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"]


class QuarkMoEMethod(FusedMoEMethodBase):

    @staticmethod
    def get_moe_method(
            quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
            module: torch.nn.Module,
            layer_name: str) -> "QuarkMoEMethod":
        layer_quant_config = quant_config._find_matched_config(
            layer_name, module)

        if (layer_quant_config.get("output_tensors")
                or layer_quant_config.get("bias")):
            raise NotImplementedError(
                "Currently, Quark models with quantized output_tensors "
                "or bias are not supported.")

        weight_config = layer_quant_config.get("weight")
        input_config = layer_quant_config.get("input_tensors")

        if quant_config._is_fp8_w8a8(weight_config, input_config):
            return QuarkW8A8Fp8MoEMethod(weight_config, input_config)
        else:
            raise RuntimeError("Unsupported FusedMoe scheme")


class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):

    def __init__(self, weight_config: dict[str, Any],
                 input_config: dict[str, Any]):
        self.weight_quant = weight_config
        self.input_quant = input_config

        weight_qscheme = self.weight_quant.get("qscheme")
        input_qscheme = self.input_quant.get("qscheme")
        if not (weight_qscheme == "per_tensor"
                and input_qscheme == "per_tensor"):
            raise ValueError(
                "For FP8 Fused MoE layers, only per-tensor scales "
                "for weights and activations are supported. Found "
                f"{weight_qscheme}, {input_qscheme}")

        self.static_input_scales = not self.input_quant.get("is_dynamic")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
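        # Illustrative shapes (example values, not enforced here): with
        # num_experts=8, intermediate_size_per_partition=14336 and
        # hidden_size=4096, w13_weight is [8, 2 * 14336, 4096] because it
        # packs the gate (w1) and up (w3) projections along dim 1, while
        # w2_weight (the down projection) is [8, 4096, 14336]; both are
        # stored in FP8 (e4m3fn) and consumed directly by the fused kernel.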
        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                         2,
                                                         dtype=torch.float32),
                                              requires_grad=False)
        layer.register_parameter("w13_weight_scale", w13_weight_scale)

        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
                                                        dtype=torch.float32),
                                             requires_grad=False)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.static_input_scales:
            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.static_input_scales:
            if (layer.w13_input_scale is None
                    or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer.")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        if current_platform.is_fp8_fnuz():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameter
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(
                w13_weight_scale, requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(
                w2_weight_scale, requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False)

        # Fp8 moe kernel needs single weight scale for w13 per expert.
        # We take the max then dequant and requant each expert.
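        # Illustrative math (comments only, nothing executed here): each
        # shard was stored as w_fp8 = round(w / s_shard), so the loop below
        # recovers w ≈ w_fp8 * s_shard via per_tensor_dequantize() and then
        # requantizes with the per-expert maximum scale s_max. Taking the
        # max keeps every requantized value inside the FP8 range, at a
        # small precision cost for the shard whose original scale was
        # smaller.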
        assert layer.w13_weight_scale is not None
        shard_size = layer.intermediate_size_per_partition
        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
        for expert_id in range(layer.local_num_experts):
            start = 0
            for shard_id in range(2):
                dq_weight = per_tensor_dequantize(
                    layer.w13_weight[expert_id][start:start + shard_size, :],
                    layer.w13_weight_scale[expert_id][shard_id])
                layer.w13_weight[expert_id][
                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                        dq_weight, max_w13_scales[expert_id])
                start += shard_size

        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                    requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias)

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            use_fp8_w8a8=True,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale)
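
# Typical lifecycle (illustrative summary, nothing executed here): the owning
# FusedMoE layer calls create_weights() to allocate empty FP8 parameters, the
# checkpoint loader fills them in, process_weights_after_loading() fuses the
# per-shard w13 scales (normalizing to e4m3fnuz on platforms that require
# it), and apply() runs routing via FusedMoE.select_experts() followed by the
# fused FP8 expert kernel on each forward pass.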