import functools
import importlib.util
from typing import Any, Callable, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE,
                                                  FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               QKVParallelLinear,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    apply_fp8_marlin_linear, prepare_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape, is_layer_skipped)
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    apply_fp8_block_linear, check_aiter_fp8_linear_support,
    create_fp8_input_scale, create_fp8_scale_parameter,
    create_fp8_weight_parameter, expert_weight_is_col_major,
    maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy,
    process_fp8_weight_tensor_strategy, requant_weight_ue8m0_inplace,
    validate_fp8_block_shape)
# from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
#     all_close_1d, apply_fp8_linear, convert_to_channelwise,
#     cutlass_block_fp8_supported, cutlass_fp8_supported,
#     normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
#     requantize_with_max_scale)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp, all_close_1d, convert_to_channelwise,
    cutlass_block_fp8_supported, cutlass_fp8_supported,
    maybe_create_device_identity, normalize_e4m3fn_to_e4m3fnuz,
    per_tensor_dequantize, requantize_with_max_scale)
from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import has_deep_gemm

logger = init_logger(__name__)

# has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None


def Fp8LinearMethod__init(self, quant_config: Fp8Config):
    self.quant_config = quant_config
    self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
    self.out_dtype = torch.get_default_dtype()

    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization
    self.use_marlin = (not current_platform.has_device_capability(89)
                       or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    # Disable marlin for rocm
    if current_platform.is_rocm():
        self.use_marlin = False

    self.weight_block_size = self.quant_config.weight_block_size
    self.block_quant = self.quant_config.weight_block_size is not None

    self.act_q_static = self.quant_config.activation_scheme == "static"
    # Use per-token quantization for better perf if dynamic and cutlass
    if not self.act_q_static and cutlass_fp8_supported():
        self.act_q_group_shape = GroupShape.PER_TOKEN
    else:
        self.act_q_group_shape = GroupShape.PER_TENSOR

    if self.block_quant:
        self.block_size = self.quant_config.weight_block_size
        # Marlin doesn't support block-wise fp8
        self.use_marlin = False

    self.scale_k = 1
    self.scale_n = 1
    self.scale_n_prefill = 1  # only for fp8 moe

    self.fp8_linear = Fp8LinearOp(
        act_quant_static=self.act_q_static,
        act_quant_group_shape=self.act_q_group_shape)
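

# NOTE (assumption, not verified against the rest of this package): the
# module-level ``Fp8LinearMethod__init`` above and ``Fp8MoEMethod_init_``
# further below look like replacements for the upstream vLLM ``__init__``
# methods, presumably bound onto the classes by patching code elsewhere,
# roughly along the lines of:
#
#     Fp8LinearMethod.__init__ = Fp8LinearMethod__init
#     Fp8MoEMethod.__init__ = Fp8MoEMethod_init_
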
class Fp8LinearMethod(LinearMethodBase):

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        if self.block_quant:
            scale_n = extra_weight_attrs.get("scale_n")
            scale_k = extra_weight_attrs.get("scale_k")
            if scale_n is not None:
                self.scale_n = scale_n
            if scale_k is not None:
                self.scale_k = scale_k
            assert self.weight_block_size is not None
            layer.weight_block_size = self.weight_block_size

            tp_size = get_tensor_model_parallel_world_size()
            assert self.quant_config.weight_block_size is not None
            block_n, block_k = (
                self.quant_config.weight_block_size[0] // self.scale_n,
                self.quant_config.weight_block_size[1] // self.scale_k,
            )
            # Required by row parallel
            if (tp_size > 1
                    and input_size // input_size_per_partition == tp_size
                    and input_size_per_partition % block_k != 0):
                raise ValueError(
                    f"Weight input_size_per_partition = "
                    f"{input_size_per_partition} is not divisible by "
                    f"weight quantization block_k = {block_k}.")
            # Required by column parallel or enabling merged weights
            if (tp_size > 1 and output_size // output_size_per_partition
                    == tp_size) or len(output_partition_sizes) > 1:
                for output_partition_size in output_partition_sizes:
                    if output_partition_size % block_n != 0:
                        raise ValueError(
                            f"Weight output_partition_size = "
                            f"{output_partition_size} is not divisible by "
                            f"weight quantization block_n = {block_n}.")

        layer.logical_widths = output_partition_sizes
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype

        # WEIGHT
        weight_dtype = (torch.float8_e4m3fn
                        if self.quant_config.is_checkpoint_fp8_serialized else
                        params_dtype)

        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=weight_dtype),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)
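
        # Illustrative shape check (assumed numbers, not from any real model):
        # with weight_block_size = [128, 128] and scale_n = scale_k = 1, a
        # 1024 x 4096 fp8 weight gets an 8 x 32 ``weight_scale_inv`` below;
        # bumping scale_n to 4 shrinks the effective block_n to 32 and grows
        # the scale tensor to 32 x 32.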
        # If checkpoint is serialized fp8, load them.
        # Otherwise, wait until process_weights_after_loading.
        if self.quant_config.is_checkpoint_fp8_serialized:
            # WEIGHT SCALE
            if not self.block_quant:
                scale = PerTensorScaleParameter(
                    data=torch.empty(len(output_partition_sizes),
                                     dtype=torch.float32),
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("weight_scale", scale)
            else:
                assert self.quant_config.activation_scheme == "dynamic"
                scale = BlockQuantScaleParameter(
                    data=torch.empty(
                        (output_size_per_partition + block_n - 1) // block_n,
                        (input_size_per_partition + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    input_dim=1,
                    output_dim=0,
                    weight_loader=weight_loader,
                )
                scale[:] = torch.finfo(torch.float32).min
                # The weight_scale_inv name is intentional for deepseekv3
                layer.register_parameter("weight_scale_inv", scale)

            # INPUT ACTIVATION SCALE
            if self.quant_config.activation_scheme == "static":
                scale = PerTensorScaleParameter(data=torch.empty(
                    len(output_partition_sizes), dtype=torch.float32),
                                                weight_loader=weight_loader)
                scale[:] = torch.finfo(torch.float32).min
                layer.register_parameter("input_scale", scale)
            else:
                layer.register_parameter("input_scale", None)

    def process_weights_after_loading(self, layer: Module) -> None:
        # TODO(rob): refactor block quant into separate class.
        if self.block_quant:
            assert self.quant_config.activation_scheme == "dynamic"
            if current_platform.is_fp8_fnuz():
                weight, weight_scale_inv, _ = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=layer.weight,
                        weight_scale=layer.weight_scale_inv)
            else:
                weight = layer.weight.data
                weight_scale_inv = layer.weight_scale_inv.data

            if isinstance(layer, QKVParallelLinear):
                # NOTE: for QKVParallelLinear the number of weight_scale rows
                # should be divisible by 8 (DSPs), so repeat the scales along
                # dim 0 by the smallest power of two that makes them so.
                rows = weight_scale_inv.shape[0]
                repeat = 1
                while (rows * repeat) % 8 != 0:
                    repeat *= 2
                weight_scale_inv = torch.repeat_interleave(weight_scale_inv,
                                                           repeats=repeat,
                                                           dim=0)

            # weight = self._maybe_pad_weight(weight)
            # if self.block_quant:
            #     maybe_post_process_fp8_weight_block(
            #         layer, self.cutlass_block_fp8_supported)

            # Torch.compile cannot use Parameter subclasses.
            layer.weight = Parameter(weight, requires_grad=False)
            layer.weight_scale_inv = Parameter(weight_scale_inv,
                                               requires_grad=False)
            return
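
    # Rough sketch of the shape arithmetic used in ``apply`` below (assumed,
    # purely illustrative numbers): a 1024 x 4096 fp8 weight paired with a
    # 32 x 32 ``weight_scale_inv`` implies
    #     block_size = [1024 // 32, 4096 // 32] = [32, 128].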
    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.use_marlin:
            return apply_fp8_marlin_linear(
                input=x,
                weight=layer.weight,
                weight_scale=layer.weight_scale,
                workspace=layer.workspace,
                size_n=layer.output_size_per_partition,
                size_k=layer.input_size_per_partition,
                bias=bias)

        # Note: lazy import to avoid triton import error.
        from vllm.model_executor.layers.quantization.utils.fp8_utils import (
            apply_w8a8_block_fp8_linear)

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            return apply_w8a8_block_fp8_linear(
                input=x,
                weight=layer.weight,
                block_size=[
                    layer.weight.shape[0] // layer.weight_scale_inv.shape[0],
                    layer.weight.shape[1] // layer.weight_scale_inv.shape[1]
                ],
                weight_scale=layer.weight_scale_inv,
                input_scale=layer.input_scale,
                bias=bias,
                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
            )

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)
        # return apply_fp8_linear(
        #     input=x,
        #     weight=layer.weight,
        #     weight_scale=layer.weight_scale,
        #     input_scale=layer.input_scale,
        #     bias=bias,
        #     cutlass_fp8_supported=self.cutlass_fp8_supported,
        #     # Default to using per_token quantization if cutlass is supported
        #     use_per_token_if_dynamic=self.cutlass_fp8_supported)


def Fp8MoEMethod_init_(self, quant_config: Fp8Config, layer: torch.nn.Module):
    self.layer = layer
    from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts

    self.quant_config = quant_config
    self.block_quant = self.quant_config.weight_block_size is not None
    self.flashinfer_moe_backend = None
    self.scale_k = 1
    self.scale_n = 1
    self.scale_n_prefill = 1

    # For GPUs that lack FP8 hardware support, we can leverage the Marlin
    # kernel for fast weight-only FP8 quantization
    self.use_marlin = (not current_platform.has_device_capability(89)
                       or envs.VLLM_TEST_FORCE_FP8_MARLIN)
    # Disable marlin for rocm and vacc
    if current_platform.is_rocm() or current_platform.is_vacc():
        self.use_marlin = False

    # Check for DeepGemm support.
    self.allow_deep_gemm = False
    if envs.VLLM_USE_DEEP_GEMM:
        if not has_deep_gemm():
            logger.warning_once("Failed to import DeepGemm kernels.")
        elif not self.block_quant:
            logger.warning_once("Model is not block quantized. Not using "
                                "DeepGemm kernels")
        elif (current_platform.is_cuda()
              and current_platform.has_device_capability(90)):
            logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.")
            self.allow_deep_gemm = True
        else:
            logger.warning_once(
                "DeepGemm not supported on the current platform.")

    # Check for CutlassBlockScaledGroupedGemm support.
    self.allow_cutlass_block_scaled_grouped_gemm = False
    if not self.block_quant:
        logger.warning_once("Model is not block quantized. Not using "
                            "CutlassBlockScaledGroupedGemm kernels")
    elif (current_platform.is_cuda()
          and current_platform.has_device_capability(100)):
        logger.info_once(
            "Using CutlassBlockScaledGroupedGemm kernels for Fp8MoEMethod.")
        self.allow_cutlass_block_scaled_grouped_gemm = True
    else:
        logger.warning_once(
            "CutlassBlockScaledGroupedGemm not supported on the current "
            "platform.")

    self.topk_indices_dtype = None
    self.fused_experts = functools.partial(  # type: ignore
        fused_experts,
        use_fp8_w8a8=True,
        block_shape=self.quant_config.weight_block_size,
        allow_deep_gemm=self.allow_deep_gemm,
        allow_cutlass_block_scaled_grouped_gemm=(
            self.allow_cutlass_block_scaled_grouped_gemm))
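

# NOTE: ``Fp8MoEMethod_init_`` binds ``self.fused_experts`` to vLLM's generic
# ``fused_experts``; the vacc-specific path in ``Fp8MoEMethod.moe_fp8_apply``
# below imports ``fused_experts`` from ``torch_vacc.vacc.custom_ops`` directly,
# falling back to ``torch_vacc.vacc.custom_ops_cpu`` if that call fails.
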
class Fp8MoEMethod(FusedMoEMethodBase):

    def create_weights(self, layer: Module, num_experts: int,
                       hidden_size: int,
                       intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = torch.float8_e4m3fn

        if self.block_quant:
            assert self.quant_config.weight_block_size is not None
            scale_n = extra_weight_attrs.get("scale_n")
            scale_n_prefill = extra_weight_attrs.get("scale_n_prefill")
            scale_k = extra_weight_attrs.get("scale_k")
            if scale_n is not None:
                self.scale_n = scale_n
            if scale_k is not None:
                self.scale_k = scale_k
            if scale_n_prefill is not None:
                self.scale_n_prefill = scale_n_prefill

            if (self.quant_config is not None
                    and self.quant_config.weight_block_size is not None):
                self.gcd_value = self.quant_config.weight_block_size[0]
                output_size_no_merge = intermediate_size_per_partition
                # assert isinstance(output_size_no_merge, int), (
                #     f"merged output size should be an int, value is: "
                #     f"{output_size_no_merge}")
                if output_size_no_merge % self.quant_config.weight_block_size[0]:
                    import math
                    gcd_value = math.gcd(
                        output_size_no_merge %
                        self.quant_config.weight_block_size[0],
                        self.quant_config.weight_block_size[0])
                    self.scale_n = (self.scale_n *
                                    self.quant_config.weight_block_size[0] //
                                    gcd_value)
                    self.scale_n_prefill = (
                        self.scale_n_prefill *
                        self.quant_config.weight_block_size[0] // gcd_value)
                if hidden_size % self.quant_config.weight_block_size[1]:
                    import math
                    gcd_value = math.gcd(
                        hidden_size % self.quant_config.weight_block_size[1],
                        self.quant_config.weight_block_size[1])
                    self.scale_k = (self.scale_k *
                                    self.quant_config.weight_block_size[1] //
                                    gcd_value)
                # self.scale_k = self.scale_n
                # print('output_size_no_merge', output_size_no_merge)

                # Split the output along block_size across cores:
                #   output_size_no_merge = 384
                #     block_size = 128: 384 = 3 x 128, only 3 cores x 128
                #     block_size = 16:  384 = 24 x 16, spreads to
                #                       8 cores x (3 x 16)
                #   output_size_no_merge = 512
                #     block_size = 128: 512 = 4 x 128, only 4 cores x 128
                #     block_size = 64:  512 = 8 x 64, spreads to 8 cores x 64
                #   output_size_no_merge = 768
                #     block_size = 128: 768 = 6 x 128, only 6 cores x 128
                #     block_size = 32:  768 = 8 x (3 x 32), spreads to
                #                       8 cores x (3 x 32)
                core_num = 8
                min_block_size = 4
                block_size_tmp = (self.quant_config.weight_block_size[0] //
                                  self.scale_n)
                if (output_size_no_merge > block_size_tmp
                        and output_size_no_merge % block_size_tmp == 0
                        and output_size_no_merge // block_size_tmp < core_num
                        and output_size_no_merge % core_num == 0):
                    core_num_old = output_size_no_merge // block_size_tmp
                    import math
                    gcd_value = math.gcd(core_num, core_num_old)
                    new_scale = core_num // gcd_value
                    if block_size_tmp // new_scale >= min_block_size:
                        self.scale_n = new_scale * self.scale_n
                # print("moe scale n is:", self.scale_n, self.scale_k,
                #       intermediate_size_per_partition)

            tp_size = get_tensor_model_parallel_world_size()
            if self.scale_n != self.scale_n_prefill:
                block_n_prefill = (self.quant_config.weight_block_size[0] //
                                   self.scale_n_prefill)
            block_n, block_k = (
                self.quant_config.weight_block_size[0] // self.scale_n,
                self.quant_config.weight_block_size[1] // self.scale_k,
            )
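
            # Worked example for the gcd adjustment above (assumed numbers,
            # illustration only): with weight_block_size = [128, 128] and
            # intermediate_size_per_partition = 192, 192 % 128 = 64 and
            # gcd(64, 128) = 64, so scale_n is scaled by 128 // 64 = 2 and
            # the effective block_n becomes 128 // 2 = 64, which divides 192
            # evenly.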
            # NOTE: To ensure proper alignment of the block-wise quantization
            # scales, the output_size of the weights for both the gate and up
            # layers must be divisible by block_n.
            # Required by column parallel or enabling merged weights
            if intermediate_size_per_partition % block_n != 0:
                raise ValueError(
                    f"The output_size of gate's and up's weight = "
                    f"{intermediate_size_per_partition} is not divisible by "
                    f"weight quantization block_n = {block_n}.")
            # Required by row parallel
            if tp_size > 1 and hidden_size % block_k != 0:
                raise ValueError(
                    f"The input_size of down's weight = "
                    f"{hidden_size} is not divisible by "
                    f"weight quantization block_k = {block_k}.")

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if not self.block_quant:
            # Allocate 2 scales for w1 and w3 respectively.
            # They will be combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
        else:
            w13_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    2 * ((intermediate_size_per_partition + block_n - 1) //
                         block_n),
                    (hidden_size + block_k - 1) // block_k,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            w2_weight_scale = torch.nn.Parameter(
                torch.ones(
                    num_experts,
                    (hidden_size + block_k - 1) // block_k,
                    (intermediate_size_per_partition + block_n - 1) //
                    block_n,
                    dtype=torch.float32,
                ),
                requires_grad=False,
            )
            if self.scale_n != self.scale_n_prefill:
                w13_weight_scale_prefill = torch.nn.Parameter(
                    torch.ones(
                        num_experts,
                        2 * ((intermediate_size_per_partition +
                              block_n_prefill - 1) // block_n_prefill),
                        (hidden_size + block_k - 1) // block_k,
                        dtype=torch.float32,
                    ),
                    requires_grad=False,
                )
                w2_weight_scale_prefill = torch.nn.Parameter(
                    torch.ones(
                        num_experts,
                        (hidden_size + block_k - 1) // block_k,
                        (intermediate_size_per_partition + block_n_prefill -
                         1) // block_n_prefill,
                        dtype=torch.float32,
                    ),
                    requires_grad=False,
                )
                layer.register_parameter("w13_weight_scale_inv_prefill",
                                         w13_weight_scale_prefill)
                layer.register_parameter("w2_weight_scale_inv_prefill",
                                         w2_weight_scale_prefill)
            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
            assert self.quant_config.activation_scheme == "dynamic"
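
        # Illustrative scale shapes (assumed numbers only): with
        # num_experts = 8, intermediate_size_per_partition = 1536,
        # hidden_size = 4096 and block_n = block_k = 128, ``w13_weight`` is
        # (8, 3072, 4096) in fp8 while ``w13_weight_scale_inv`` is
        # (8, 24, 32) and ``w2_weight_scale_inv`` is (8, 32, 12) in fp32.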
        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
            if self.block_quant else
            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})

        # If loading an fp8 checkpoint, pass the weight loaders.
        # If loading an fp16 checkpoint, do not (we will quantize in
        # process_weights_after_loading()).
        if self.quant_config.is_checkpoint_fp8_serialized:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
            if self.scale_n != self.scale_n_prefill:
                set_weight_attrs(w13_weight_scale_prefill, extra_weight_attrs)
                set_weight_attrs(w2_weight_scale_prefill, extra_weight_attrs)

        # INPUT_SCALES
        if self.quant_config.activation_scheme == "static":
            if not self.quant_config.is_checkpoint_fp8_serialized:
                raise ValueError(
                    "Found static activation scheme for checkpoint that "
                    "was not serialized fp8.")

            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def moe_fp8_apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
    ) -> torch.Tensor:
        from vllm.model_executor.layers.fused_moe import FusedMoE

        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
        )

        try:
            from torch_vacc.vacc.custom_ops import fused_experts
            from vllm_vacc.vllm.model_executor.models.memory.memory_recycling import (
                memory_recycler)

            experts_output = None
            if memory_recycler is not None:
                # remove MOE_EXPERT_OUT_BUFFER
                # experts_output = memory_recycler.MOE_EXPERT_OUT_BUFFER
                experts_output = memory_recycler.MOE_SHARED_MLP_OUT_BUFFER
            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                use_fp8_w8a8=True,
                w13_scale=(layer.w13_weight_scale_inv
                           if self.block_quant else layer.w13_weight_scale),
                w2_scale=(layer.w2_weight_scale_inv
                          if self.block_quant else layer.w2_weight_scale),
                a13_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
                decode_with_batch=layer.is_decode and x.shape[0] > 1,
                output_opt=experts_output)
        except Exception as e:
            logger.warning(
                "vacc fused_experts failed (%s); falling back to unfused ops.",
                e)
            from torch_vacc.vacc.custom_ops_cpu import fused_experts
            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                use_fp8_w8a8=True,
                w13_scale=(layer.w13_weight_scale_inv
                           if self.block_quant else layer.w13_weight_scale),
                w2_scale=(layer.w2_weight_scale_inv
                          if self.block_quant else layer.w2_weight_scale),
                a13_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
            )