# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
import functools
from functools import partial
import importlib.util
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm import envs
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
from vllm.model_executor.layers.quantization.fp8 import (
    ACTIVATION_SCHEMES,
    Fp8Config,
    Fp8LinearMethod,
    Fp8MoeBackend,
    Fp8MoEMethod,
)
from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
    FlashinferMoeBackend,
    flashinfer_cutlass_moe_fp8,
    get_flashinfer_moe_backend,
)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    W8A8BlockFp8LinearOp,
    create_fp8_input_scale,
    create_fp8_scale_parameter,
    create_fp8_weight_parameter,
    validate_fp8_block_shape,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    convert_to_channelwise, cutlass_block_fp8_supported, cutlass_fp8_supported,
    normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale,
    maybe_create_device_identity, Fp8LinearOp)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter, ChannelQuantScaleParameter, ModelWeightParameter,
    PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import (
    is_deep_gemm_e8m0_used,
    is_deep_gemm_supported,
)
from vllm.utils.flashinfer import has_flashinfer_moe
from vllm.utils.import_utils import has_deep_gemm

from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.fused_moe.utils import _fp8_quantize
import vllm_mlu._mlu_ops as mlu_ops

logger = init_logger(__name__)


def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
    """
    Select the primary FP8 MoE backend.

    Note: Shape-specific fallbacks may still occur at runtime.
    """
    # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
    if (
        current_platform.is_cuda()
        and (
            current_platform.is_device_capability(100)
            or current_platform.is_device_capability(90)
        )
        and envs.VLLM_USE_FLASHINFER_MOE_FP8
        and has_flashinfer_moe()
    ):
        backend = get_flashinfer_moe_backend()
        if backend == FlashinferMoeBackend.TENSORRT_LLM:
            logger.info_once("Using FlashInfer FP8 MoE TRTLLM backend for SM100")
            return Fp8MoeBackend.FLASHINFER_TRTLLM
        else:
            if block_quant and current_platform.is_device_capability(100):
                raise ValueError(
                    "FlashInfer FP8 MoE throughput backend does not "
                    "support block quantization. Please use "
                    "VLLM_FLASHINFER_MOE_BACKEND=latency instead."
                )
            logger.info_once(
                "Using FlashInfer FP8 MoE CUTLASS backend for SM90/SM100"
            )
            return Fp8MoeBackend.FLASHINFER_CUTLASS

    # Weight-only path for older GPUs without native FP8.
    use_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    '''
    ============================= Modify by vllm_mlu =============================
    @brief: disable marlin for MLU backend.
    '''
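    # Note: Marlin kernels are CUDA-specific; ROCm and out-of-tree platforms
    # such as MLU provide their own paths, so the weight-only fallback is
    # switched off for them here.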
    if current_platform.is_rocm() or current_platform.is_out_of_tree():
        use_marlin = False
    '''
    ================== End of MLU Hijack ==================
    '''
    if use_marlin:
        logger.info_once("Using Marlin backend for FP8 MoE")
        return Fp8MoeBackend.MARLIN

    # DeepGEMM on supported platforms with block-quantized weights.
    if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
        if not has_deep_gemm():
            logger.warning_once("DeepGEMM backend requested but not available.")
        elif is_deep_gemm_supported():
            logger.info_once("Using DeepGEMM backend for FP8 MoE")
            return Fp8MoeBackend.DEEPGEMM

    # CUTLASS BlockScaled GroupedGemm on SM100 with block-quantized weights.
    if (
        current_platform.is_cuda()
        and current_platform.is_device_capability(100)
        and block_quant
    ):
        logger.info_once(
            "Using Cutlass BlockScaled GroupedGemm backend for FP8 MoE"
        )
        return Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM

    # Default to Triton.
    logger.info_once("Using Triton backend for FP8 MoE")
    return Fp8MoeBackend.TRITON


Fp8Config____init____org = Fp8Config.__init__


def vllm__model_executor__layers__quantization__fp8__Fp8Config____init__(
    self,
    is_checkpoint_fp8_serialized: bool = False,
    activation_scheme: str = "dynamic",
    ignored_layers: list[str] | None = None,
    weight_block_size: list[int] | None = None,
    activation_quant_method: Optional[str] = None,
    weight_quant_method: Optional[str] = None,
) -> None:
    super(Fp8Config, self).__init__()
    Fp8Config____init____org(
        self,
        is_checkpoint_fp8_serialized,
        activation_scheme,
        ignored_layers,
        weight_block_size,
    )
    '''
    ============================= Modify by vllm_mlu =============================
    @brief: Add class members activation_quant_method and weight_quant_method
            to indicate the granularity of quantization.
    '''
    self.activation_quant_method = activation_quant_method
    self.weight_quant_method = weight_quant_method
    assert self.weight_block_size or (
        self.activation_quant_method == "per_token"
        and self.weight_quant_method == "per_channel"
        and self.activation_scheme == "dynamic"
    ), (
        "Only block-wise quantization, or dynamic per-token activation "
        "with per-channel weight quantization, is supported."
    )
    '''
    ================== End of MLU Hijack ==================
    '''


@classmethod
def vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config(
    cls, config: Dict[str, Any]
) -> "Fp8Config":
    quant_method = cls.get_from_keys(config, ["quant_method"])
    is_checkpoint_fp8_serialized = "fp8" in quant_method
    activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
    ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None)
    weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None)
    if not ignored_layers:
        ignored_layers = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
    '''
    ============================= Modify by vllm_mlu =============================
    @brief: Add config members activation_quant_method and weight_quant_method
            to indicate the granularity of quantization.
    '''
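    # Illustrative (hypothetical) checkpoint config this parser would accept;
    # the field names are exactly the keys read below:
    #   "quantization_config": {
    #       "quant_method": "fp8",
    #       "activation_scheme": "dynamic",
    #       "activation_quant_method": "per_token",
    #       "weight_quant_method": "per_channel"
    #   }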
    activation_quant_method = cls.get_from_keys_or(
        config, ["activation_quant_method"], "per_token"
    )
    weight_quant_method = cls.get_from_keys_or(
        config, ["weight_quant_method"], None
    )
    return cls(
        is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
        activation_scheme=activation_scheme,
        ignored_layers=ignored_layers,
        weight_block_size=weight_block_size,
        activation_quant_method=activation_quant_method,
        weight_quant_method=weight_quant_method,
    )
    '''
    ================== End of MLU Hijack ==================
    '''


def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights(
    self,
    layer: torch.nn.Module,
    input_size_per_partition: int,
    output_partition_sizes: list[int],
    input_size: int,
    output_size: int,
    params_dtype: torch.dtype,
    **extra_weight_attrs,
):
    maybe_create_device_identity()

    output_size_per_partition = sum(output_partition_sizes)
    weight_loader = extra_weight_attrs.get("weight_loader")

    layer.logical_widths = output_partition_sizes
    layer.input_size_per_partition = input_size_per_partition
    layer.output_size_per_partition = output_size_per_partition
    layer.orig_dtype = params_dtype
    layer.weight_block_size = None
    '''
    ============================= Modify by vllm_mlu =============================
    @brief: add tp_group.
    '''
    tp_group = extra_weight_attrs.get("tp_group", None)
    '''
    ================== End of MLU Hijack ==================
    '''
    if self.block_quant:
        assert self.weight_block_size is not None
        layer.weight_block_size = self.weight_block_size
        validate_fp8_block_shape(
            layer,
            input_size,
            output_size,
            input_size_per_partition,
            output_partition_sizes,
            self.weight_block_size,
        )
    '''
    ============================= Modify by vllm_mlu =============================
    @brief: add tp_group.
    '''
    # WEIGHT
    if self.quant_config.is_checkpoint_fp8_serialized:
        weight = create_fp8_weight_parameter(
            output_size_per_partition, input_size_per_partition, weight_loader
        )
    else:
        # For non-serialized checkpoints, use original dtype
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
            tp_group=tp_group,
        )
    '''
    ================== End of MLU Hijack ==================
    '''
    layer.register_parameter("weight", weight)

    # If checkpoint is serialized fp8, load them.
    # Otherwise, wait until process_weights_after_loading.
    if self.quant_config.is_checkpoint_fp8_serialized:
        # WEIGHT SCALE
        if not self.block_quant:
            '''
            ============================= Modify by vllm_mlu =============================
            @brief: Support weight per channel quantization.
            @brief: Add tp_group to enable custom split.
            '''
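            # Per-channel scales hold one fp32 value per output row
            # (sum(output_partition_sizes)); the per-tensor path keeps one
            # scalar per logical weight so fused modules (e.g. QKV) can later
            # be requantized to a shared max scale.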
            if self.weight_per_channel:
                scale = ChannelQuantScaleParameter(
                    data=torch.empty(
                        sum(output_partition_sizes), dtype=torch.float32
                    ),
                    output_dim=0,
                    weight_loader=weight_loader,
                    tp_group=tp_group,
                )
            else:
                scale = PerTensorScaleParameter(
                    data=torch.empty(
                        len(output_partition_sizes), dtype=torch.float32
                    ),
                    weight_loader=weight_loader,
                )
            scale[:] = torch.finfo(torch.float32).min
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            layer.register_parameter("weight_scale", scale)
            '''
            ================== End of MLU Hijack ==================
            '''
        else:
            assert not self.act_q_static
            assert self.weight_block_size is not None
            scale = create_fp8_scale_parameter(
                BlockQuantScaleParameter,
                output_partition_sizes,
                input_size_per_partition,
                self.weight_block_size,
                weight_loader,
            )
            set_weight_attrs(scale, {"scale_type": "weight_scale"})
            # The weight_scale_inv name is intentional for deepseekv3
            layer.register_parameter("weight_scale_inv", scale)

        # INPUT ACTIVATION SCALE
        if self.act_q_static:
            scale = create_fp8_input_scale(output_partition_sizes, weight_loader)
            set_weight_attrs(scale, {"scale_type": "input_scale"})
            layer.register_parameter("input_scale", scale)
        else:
            layer.register_parameter("input_scale", None)


def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__(
    self, quant_config: Fp8Config
):
    self.quant_config = quant_config
    self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
    self.out_dtype = torch.get_default_dtype()

    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization.
    self.use_marlin = (
        not current_platform.has_device_capability(89)
        or envs.VLLM_TEST_FORCE_FP8_MARLIN
    )
    # Disable marlin for rocm
    if current_platform.is_rocm():
        self.use_marlin = False
    if vllm_is_batch_invariant():
        self.use_marlin = False

    # AITER is only supported on ROCm and only for FP8_FNUZ,
    # which at the moment means the MI300 series.
    self.use_aiter_and_is_supported = rocm_aiter_ops.is_linear_fp8_enabled()
    self.use_deep_gemm = is_deep_gemm_supported()

    self.weight_block_size = self.quant_config.weight_block_size
    self.block_quant = self.weight_block_size is not None
    if self.block_quant:
        # Marlin doesn't support block-wise fp8
        self.use_marlin = False

    self.act_q_static = self.quant_config.activation_scheme == "static"
    if self.weight_block_size:
        self.act_q_group_shape = GroupShape(1, self.weight_block_size[0])
    else:
        # Use per-token quantization for better perf if dynamic and cutlass
        if not self.act_q_static and cutlass_fp8_supported():
            self.act_q_group_shape = GroupShape.PER_TOKEN
        else:
            self.act_q_group_shape = GroupShape.PER_TENSOR
    '''
    ============================= Modify by vllm_mlu =============================
    @brief: Add config members activation_quant_method and weight_quant_method
            to indicate the granularity of quantization.
    '''
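    # The dynamic per-token activation / per-channel weight scheme is handled
    # by fp8_linear.apply (see apply below), so Marlin is not needed for it.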
    self.weight_per_channel = (
        self.quant_config.weight_quant_method == "per_channel"
    )
    self.activation_per_token = (
        self.quant_config.activation_quant_method == "per_token"
    )
    if self.weight_per_channel and self.activation_per_token:
        self.use_marlin = False
    '''
    ================== End of MLU Hijack ==================
    '''
    if self.block_quant:
        assert not self.act_q_static
        assert self.weight_block_size is not None
        self.w8a8_block_fp8_linear = W8A8BlockFp8LinearOp(
            weight_group_shape=GroupShape(*self.weight_block_size),
            act_quant_group_shape=self.act_q_group_shape,
            cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            use_aiter_and_is_supported=self.use_aiter_and_is_supported,
        )
    else:
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.act_q_static,
            act_quant_group_shape=self.act_q_group_shape,
        )


Fp8LinearMethod__process_weights_after_loading__org = (
    Fp8LinearMethod.process_weights_after_loading
)


def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading(
    self,
    layer: Module,
) -> None:
    '''
    ============================= Modify by vllm_mlu =============================
    @brief: For dynamic activation and channel-wise weight quantization,
            additional processing is not needed.
    '''
    if (
        self.quant_config.is_checkpoint_fp8_serialized
        and self.weight_per_channel
        and self.quant_config.activation_scheme == "dynamic"
    ):
        return
    '''
    ================== End of MLU Hijack ==================
    '''
    Fp8LinearMethod__process_weights_after_loading__org(self=self, layer=layer)


def vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    residual: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    assert residual is None, "Fp8Linear residual is not supported yet."

    # If batch invariant mode is enabled, prefer the DeepGEMM FP8 path;
    # we will use BF16 dequant when DeepGEMM is not supported.
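    # (Batch-invariant mode trades some throughput for bitwise-reproducible
    # outputs that do not depend on how requests are batched together.)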
    if vllm_is_batch_invariant():
        if self.block_quant:
            assert self.weight_block_size is not None
            return self.w8a8_block_fp8_linear.apply(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                input_scale=layer.input_scale,
                bias=bias,
            )
        else:
            # per-tensor/channel: dequant to BF16 and run GEMM
            weight_fp8 = layer.weight.to(torch.bfloat16)
            weight_scale = layer.weight_scale.to(torch.bfloat16)

            if weight_scale.numel() == 1:
                # Per-tensor: simple scalar multiplication
                weight_bf16 = weight_fp8 * weight_scale
            else:
                # Multiple scales (fused modules like QKV):
                # try to infer the correct broadcasting.
                # weight is [K, N], scale could be [num_logical_weights].
                if (
                    weight_scale.dim() == 1
                    and weight_scale.shape[0] == weight_fp8.shape[0]
                ):
                    # Per-row scaling
                    weight_bf16 = weight_fp8 * weight_scale.unsqueeze(1)
                else:
                    # Fallback: direct multiplication
                    weight_bf16 = weight_fp8 * weight_scale

            return torch.nn.functional.linear(x, weight_bf16.t(), bias)

    if self.use_marlin:
        return apply_fp8_marlin_linear(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            workspace=layer.workspace,
            size_n=layer.output_size_per_partition,
            size_k=layer.input_size_per_partition,
            bias=bias,
        )

    if self.block_quant:
        assert self.weight_block_size is not None
        from vllm_mlu.model_executor.layers.quantization.utils.fp8_utils import (
            apply_w8a8_block_fp8_linear)

        return apply_w8a8_block_fp8_linear(
            input=x,
            weight=layer.weight,
            block_size=self.quant_config.weight_block_size,
            weight_scale=layer.weight_scale,
            input_scale=layer.input_scale,
            bias=bias,
        )
    '''
    ============================= Modify by vllm_mlu =============================
    @brief: Use activation per token quantization based on quantization config.
    '''
    return self.fp8_linear.apply(
        input=x,
        weight=layer.weight,
        weight_scale=layer.weight_scale,
        out_dtype=self.out_dtype,
        input_scale=layer.input_scale,
        bias=bias,
        weight_per_channel=self.weight_per_channel,
        activation_per_token=self.activation_per_token,
    )
    '''
    ================== End of MLU Hijack ==================
    '''


def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__(
    self, quant_config: Fp8Config, layer: torch.nn.Module
):
    super(Fp8MoEMethod, self).__init__(layer.moe_config)
    self.layer = layer
    self.quant_config = quant_config
    self.weight_block_size = self.quant_config.weight_block_size
    self.block_quant: bool = self.weight_block_size is not None

    self.fp8_backend = get_fp8_moe_backend(self.block_quant)
    self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN

    self.flashinfer_moe_backend: FlashinferMoeBackend | None = None
    if self.fp8_backend == Fp8MoeBackend.FLASHINFER_TRTLLM:
        self.flashinfer_moe_backend = FlashinferMoeBackend.TENSORRT_LLM
    elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS:
        self.flashinfer_moe_backend = FlashinferMoeBackend.CUTLASS
        if self.block_quant:
            assert self.weight_block_size == [128, 128], (
                f"Only support weight_block_size == [128, 128], "
                f"got {self.weight_block_size}"
            )
        self.flashinfer_moe_fn = partial(
            flashinfer_cutlass_moe_fp8,
            moe=self.moe,
            use_deepseek_fp8_block_scale=self.block_quant,
        )

    self.allow_deep_gemm = self.fp8_backend == Fp8MoeBackend.DEEPGEMM
    self.allow_cutlass_block_scaled_grouped_gemm = (
        self.fp8_backend == Fp8MoeBackend.CUTLASS_BLOCK_SCALED_GROUPED_GEMM
    )
    '''
    ============================= Modify by vllm_mlu =============================
    @brief: In mlu, always set self.use_marlin as False.
    '''
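    # get_fp8_moe_backend() already avoids returning MARLIN on out-of-tree
    # platforms, so this explicit override is defensive.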
    self.use_marlin = False
    '''
    ================== End of MLU Hijack ==================
    '''


def vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: int | None = None,
    num_expert_group: int | None = None,
    global_num_experts: int = -1,
    expert_map: torch.Tensor | None = None,
    custom_routing_function: Callable | None = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: torch.Tensor | None = None,
    logical_to_physical_map: torch.Tensor | None = None,
    logical_replica_count: torch.Tensor | None = None,
) -> torch.Tensor:
    if enable_eplb:
        from vllm.model_executor.layers.fused_moe.layer import FusedMoE

        assert expert_load_view is not None
        assert logical_to_physical_map is not None
        assert logical_replica_count is not None
        assert isinstance(layer, FusedMoE)
    '''
    ============================= Modify by vllm_mlu =============================
    @brief: Use moe_softmax_topk and moe_sigmoid_topk of mlu_ops to implement
            FusedMoE.select_experts
    '''
    from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts

    if scoring_func == "softmax":
        topk_weights, topk_ids = mlu_ops.moe_softmax_topk(
            router_logits,
            top_k,
            renormalize,
            num_expert_group,
            topk_group,
            route_scale=routed_scaling_factor,
        )
    elif scoring_func == "sigmoid":
        topk_weights, topk_ids = mlu_ops.moe_sigmoid_topk(
            router_logits,
            top_k,
            renormalize,
            num_expert_group,
            topk_group,
            routed_scaling_factor,
            e_score_correction_bias,
        )
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    # gen_idx: build gather/scatter indices and per-expert token counts.
    ori_input_shape = x.shape
    x = x.reshape(-1, x.size(-1))
    router_logits = router_logits.reshape(-1, router_logits.size(-1))
    expert_num = router_logits.size(-1)
    tokens_num = x.size(0)
    expert_size = layer.w13_weight.size(0)

    expand_idx, combine_idx, token_count, cumsum_token_count = mlu_ops.moe_gen_idx(
        topk_ids, expert_num
    )
    # Duplicate each token once per selected expert, grouped by expert.
    expand_hidden_states = mlu_ops.moe_expand_input(
        x, expand_idx, cumsum_token_count, 0, expert_size
    )

    # Quantize activations to FP8, then run the gate/up (w13) projection
    # as a grouped GEMM over the per-expert token groups.
    quant_input, input_scale = _fp8_quantize(
        expand_hidden_states,
        A_scale=None,
        block_shape=self.quant_config.weight_block_size,
    )
    gemm1_out = mlu_ops.smooth_quant_group_gemm(
        quant_input,
        layer.w13_weight,
        token_count,
        expand_idx=None,
        c=None,
        alpha=None,
        beta=None,
        a_scale=input_scale.T.contiguous(),
        b_scale=layer.w13_weight_scale_inv,
        dtype=x.dtype,
        max_m=tokens_num,
    )

    # Gated activation (e.g. SiLU) on the fused gate/up output.
    act_out = mlu_ops.active(gemm1_out, activation, is_gated=True)

    # Requantize and run the down (w2) projection.
    act_out_quantize, act_out_scale = _fp8_quantize(
        act_out,
        A_scale=None,
        block_shape=self.quant_config.weight_block_size,
    )
    gemm2_out = mlu_ops.smooth_quant_group_gemm(
        act_out_quantize,
        layer.w2_weight,
        token_count,
        expand_idx=None,
        c=None,
        alpha=None,
        beta=None,
        a_scale=act_out_scale.T.contiguous(),
        b_scale=layer.w2_weight_scale_inv,
        dtype=x.dtype,
        max_m=tokens_num,
    )

    # Weight expert outputs by topk_weights and scatter them back to the
    # original token order.
    output = mlu_ops.moe_combine_result(
        gemm2_out,
        topk_weights,
        combine_idx,
        residual=None,
        cusum_token_count=cumsum_token_count,
        start_expert_id=0,
        expert_size=expert_size,
        bias=None,
    )
    return output.view(ori_input_shape)
    '''
    ================== End of MLU Hijack ==================
    '''
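# MluHijackObject.apply_hijack(cls, original, replacement) (as used here)
# monkey-patches a vLLM method with the MLU-aware version defined above;
# the patches take effect as a side effect of importing this module, so
# callers of the patched classes do not need to change.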
MluHijackObject.apply_hijack(
    Fp8LinearMethod,
    Fp8LinearMethod.apply,
    vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__apply,
)
MluHijackObject.apply_hijack(
    Fp8Config,
    Fp8Config.__init__,
    vllm__model_executor__layers__quantization__fp8__Fp8Config____init__,
)
MluHijackObject.apply_hijack(
    Fp8Config,
    Fp8Config.from_config,
    vllm__model_executor__layers__quantization__fp8__Fp8Config__from_config,
)
MluHijackObject.apply_hijack(
    Fp8LinearMethod,
    Fp8LinearMethod.create_weights,
    vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__create_weights,
)
MluHijackObject.apply_hijack(
    Fp8LinearMethod,
    Fp8LinearMethod.__init__,
    vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod____init__,
)
MluHijackObject.apply_hijack(
    Fp8LinearMethod,
    Fp8LinearMethod.process_weights_after_loading,
    vllm__model_executor__layers__quantization__fp8__Fp8LinearMethod__process_weights_after_loading,
)
MluHijackObject.apply_hijack(
    Fp8MoEMethod,
    Fp8MoEMethod.__init__,
    vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod____init__,
)
MluHijackObject.apply_hijack(
    Fp8MoEMethod,
    Fp8MoEMethod.apply,
    vllm__model_executor__layers__quantization__fp8__Fp8MoEMethod__apply,
)