# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           ModelWeightParameter,
                                           RowvLLMParameter)

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.quantization.utils.common_utils import (
    str_dtype_to_torch, str_dtype_to_bits, is_fp8_str_dtype)


# @register_quantization_config("smoothquant")
class SmoothQuantConfig(QuantizationConfig):
    """Config class for SmoothQuant."""

    def __init__(
        self,
        quant_mode: str,  # smoothquant
        input_quant_method: str,  # per_token / per_tensor
        group_size: int,
        weight_precision: str,
        activation_precision: str,
        only_expert_per_group: bool,
        expert_weight_precision: str,
        expert_activation_precision: str,
        force_use_weightonly_except_expert: bool,
    ) -> None:
        super().__init__()
        self.quant_mode = quant_mode
        self.input_quant_method = input_quant_method
        self.group_size = group_size
        self.weight_precision = weight_precision
        self.activation_precision = activation_precision
        self.only_expert_per_group = only_expert_per_group
        self.expert_weight_precision = expert_weight_precision
        self.expert_activation_precision = expert_activation_precision
        self.force_use_weightonly_except_expert = force_use_weightonly_except_expert

        if quant_mode == "SmoothQuant" and self.input_quant_method not in (
                "per_token", "per_tensor"):
            raise ValueError(
                "Currently, only per_token or per_tensor input quantization "
                f"is supported for SmoothQuant, but got {self.input_quant_method}.")

        self.weight_bits = str_dtype_to_bits(self.weight_precision)
        self.expert_weight_bits = str_dtype_to_bits(self.expert_weight_precision)

        # int4 weights are stored packed (two values per element) in int8 tensors.
        if self.weight_precision == 'int4':
            self.weight_dtype = torch.int8
        else:
            self.weight_dtype = str_dtype_to_torch(self.weight_precision)
        if self.expert_weight_precision == 'int4':
            self.expert_weight_dtype = torch.int8
        else:
            self.expert_weight_dtype = str_dtype_to_torch(self.expert_weight_precision)

        self.is_fp8 = is_fp8_str_dtype(self.weight_precision)
        self.expert_is_fp8 = is_fp8_str_dtype(self.expert_weight_precision)

        self.pack_factor = 8 // self.weight_bits
        self.expert_pack_factor = 8 // self.expert_weight_bits

    def __repr__(self) -> str:
        return (f"SmoothQuantConfig(input_quant_method={self.input_quant_method}, "
                f"quant_mode={self.quant_mode}, "
                f"group_size={self.group_size}, "
                f"weight_precision={self.weight_precision}, "
                f"activation_precision={self.activation_precision}, "
                f"only_expert_per_group={self.only_expert_per_group}, "
                f"expert_weight_precision={self.expert_weight_precision}, "
                f"expert_activation_precision={self.expert_activation_precision}, "
                f"force_use_weightonly_except_expert="
                f"{self.force_use_weightonly_except_expert})")

    @classmethod
    def get_name(cls) -> str:
        return "SmoothQuant"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
        quant_mode = cls.get_from_keys(config, ["quant_mode"])
        input_quant_method = cls.get_from_keys(config, ["input_quant_method"])
        group_size = cls.get_from_keys_or(config, ["group_size"], 1)
        weight_precision = cls.get_from_keys_or(config, ["weight_precision"], "int8")
        activation_precision = cls.get_from_keys_or(config, ["activation_precision"], "int8")
        only_expert_per_group = cls.get_from_keys_or(config, ["only_expert_per_group"], False)
        expert_weight_precision = cls.get_from_keys_or(
            config, ["expert_weight_precision"], None)
        expert_activation_precision = cls.get_from_keys_or(
            config, ["expert_activation_precision"], None)
        force_use_weightonly_except_expert = cls.get_from_keys_or(
            config, ["force_use_weightonly_except_expert"], False)

        if expert_weight_precision is None:
            expert_weight_precision = weight_precision
        if group_size > 1 and only_expert_per_group and weight_precision == 'int4':
            weight_precision = 'int8'
        if expert_activation_precision is None:
            expert_activation_precision = activation_precision

        return cls(quant_mode=quant_mode,
                   input_quant_method=input_quant_method,
                   group_size=group_size,
                   weight_precision=weight_precision,
                   activation_precision=activation_precision,
                   only_expert_per_group=only_expert_per_group,
                   expert_weight_precision=expert_weight_precision,
                   expert_activation_precision=expert_activation_precision,
                   force_use_weightonly_except_expert=force_use_weightonly_except_expert)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["SmoothQuantLinearMethod"]:
        if isinstance(layer, LinearBase):
            return SmoothQuantLinearMethod(self, prefix)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


class SmoothQuantLinearMethod(LinearMethodBase):
    """Linear method for SmoothQuant.

    Args:
        quant_config: The SmoothQuant quantization config.
    """

    def __init__(self, quant_config: SmoothQuantConfig, prefix: str):
        self.quant_config = quant_config
        # For the per-tensor case, quantizing the input of the first attn/ffn
        # linear can be skipped and fused into the preceding layernorm for
        # better performance.
        self.skip_quant_input = False
        self.compute_dtype = torch.get_default_dtype()
        self.is_expert = 'expert' in prefix and "shared_expert" not in prefix
        self.weight_dtype = (quant_config.expert_weight_dtype
                             if self.is_expert else quant_config.weight_dtype)
        self.pack_factor = (quant_config.expert_pack_factor
                            if self.is_expert else quant_config.pack_factor)
        self.is_fp8 = (quant_config.expert_is_fp8
                       if self.is_expert else quant_config.is_fp8)

        if quant_config.only_expert_per_group and self.is_expert and quant_config.group_size > 1:
            self.is_group_quant = True
        elif quant_config.only_expert_per_group is False and quant_config.group_size > 1:
            self.is_group_quant = True
        else:
            self.is_group_quant = False

        self.has_smooth = self.quant_config.input_quant_method == "per_token" and (
            self.quant_config.force_use_weightonly_except_expert is False
            or self.is_expert)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")
        weight_loader = extra_weight_attrs.get("weight_loader")

        group_num = 1
        if self.is_group_quant:
            if input_size_per_partition % self.quant_config.group_size != 0:
                raise ValueError(
                    f"The input size {input_size_per_partition} is not aligned "
                    "with the quantized weight shape. This can be caused by too "
                    f"large tensor parallel size. group_size: {self.quant_config.group_size}.")
            group_num = ((input_size + self.quant_config.group_size - 1)
                         // self.quant_config.group_size)
            if input_size_per_partition != input_size:
                group_num = ((input_size_per_partition + self.quant_config.group_size - 1)
                             // self.quant_config.group_size)

        qweight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.pack_factor,
                device="mlu",
                dtype=self.weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        if self.is_group_quant:
            per_channel_scale = GroupQuantScaleParameter(
                data=torch.empty(
                    output_size_per_partition,
                    group_num,
                    device="mlu",
                    dtype=torch.float32,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            per_channel_scale = ChannelQuantScaleParameter(
                data=torch.empty(
                    output_size_per_partition,
                    device="mlu",
                    dtype=torch.float32,
                ),
                output_dim=0,
                weight_loader=weight_loader,
            )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("per_channel_scale", per_channel_scale)

        if self.has_smooth:
            smooth = RowvLLMParameter(
                data=torch.empty(
                    input_size_per_partition,
                    device="mlu",
                    dtype=torch.float32,
                ),
                input_dim=0,
                weight_loader=weight_loader,
            )
            set_weight_attrs(smooth, {
                "ignore_warning": True,
            })
            layer.register_parameter("smooth", smooth)

        if self.quant_config.input_quant_method == "per_tensor":
            scale_to_int = RowvLLMParameter(
                data=torch.empty(
                    input_size_per_partition,
                    device="mlu",
                    dtype=torch.float32,
                ),
                input_dim=0,
                weight_loader=weight_loader,
            )
            set_weight_attrs(scale_to_int, {
                "ignore_warning": True,
            })
            layer.register_parameter("scale_to_int", scale_to_int)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if self.has_smooth and layer.smooth.dtype != torch.float:
            layer.smooth = layer.smooth.to(torch.float)
        if (self.quant_config.input_quant_method == "per_tensor"
                and layer.scale_to_int.dtype != torch.float):
            layer.scale_to_int = layer.scale_to_int.to(torch.float)
        if layer.per_channel_scale.dtype != torch.float:
            layer.per_channel_scale = layer.per_channel_scale.to(torch.float)

        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
        layer.per_channel_scale = Parameter(layer.per_channel_scale.data,
                                            requires_grad=False)
        if self.has_smooth:
            layer.smooth = Parameter(layer.smooth.data, requires_grad=False)
        if self.quant_config.input_quant_method == "per_tensor":
            layer.scale_to_int = Parameter(layer.scale_to_int.data,
                                           requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None,
              residual: Optional[torch.Tensor] = None,
              input_scale: Optional[torch.Tensor] = None,
              use_tp_weight: bool = False,
              output: Optional[torch.Tensor] = None,
              ) -> torch.Tensor:
        layer_smooth = layer.smooth if self.has_smooth else None
        layer_qweight = layer.qweight
        layer_per_channel_scale = layer.per_channel_scale
        if use_tp_weight:
            if hasattr(layer, 'tp_smooth'):
                layer_smooth = layer.tp_smooth
            if hasattr(layer, 'tp_qweight'):
                layer_qweight = layer.tp_qweight
            if hasattr(layer, 'tp_per_channel_scale'):
                layer_per_channel_scale = layer.tp_per_channel_scale

        quant_input = None
        if self.skip_quant_input:
            quant_input = x
        elif self.quant_config.input_quant_method == "per_token":
            if self.is_fp8:
                quant_input, input_scale = mlu_ops.scaled_quantize(
                    x, layer_smooth, quant_type=self.weight_dtype,
                    quant_mode='dynamic_per_token')
            else:
                quant_input, input_scale = mlu_ops.per_token_smooth_quantize(
                    x, layer_smooth, None)
        elif self.quant_config.input_quant_method == "per_tensor":
            quant_input = mlu_ops.quantize(x, layer.scale_to_int, None)
        else:
            raise ValueError(
                "Currently, only per_token or per_tensor input quantization "
                "is supported for SmoothQuant, but got "
                f"{self.quant_config.input_quant_method}.")

        quant_input_shape = quant_input.shape
        if len(quant_input_shape) > 2:
            quant_input = quant_input.view(-1, quant_input_shape[-1])
            input_scale = input_scale.view(-1)
            if residual is not None and len(residual.shape) > 2:
                residual = residual.view(-1, residual.shape[-1])

        if self.is_fp8:
            out = mlu_ops.scaled_matmul(
                quant_input, layer_qweight, input_scale, layer_per_channel_scale,
                self.compute_dtype if hasattr(self, 'compute_dtype') else x.dtype,
                bias, c=residual, act_mode="none", quant_bit_size=8,
                alpha=1.0, beta=1.0, use_hp_active=False,
                a_quant_bit_size=8, a_calib=None, b_calib=None)
            if output is not None:
                out = out.view(output.shape)
                output.copy_(out)
                out = output
        else:
            if output is not None:
                out = mlu_ops.smooth_quant_matmul(quant_input, input_scale,
                                                  layer_qweight,
                                                  layer_per_channel_scale,
                                                  self.compute_dtype, bias,
                                                  residual, output=output)
            else:
                out = mlu_ops.smooth_quant_matmul(quant_input, input_scale,
                                                  layer_qweight,
                                                  layer_per_channel_scale,
                                                  self.compute_dtype, bias,
                                                  residual)

        if len(quant_input_shape) > 2:
            out = out.view(*quant_input_shape[:-1], out.shape[-1])
        return out
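

# Illustrative only: a minimal sketch of a "quantize_config.json" payload
# that SmoothQuantConfig.from_config above would accept. The exact keys
# emitted by a given quantization toolchain may differ; only "quant_mode"
# and "input_quant_method" are required here, and every other field falls
# back to the defaults read in from_config.
#
#     config = SmoothQuantConfig.from_config({
#         "quant_mode": "SmoothQuant",
#         "input_quant_method": "per_token",   # or "per_tensor"
#         "group_size": 1,                     # default
#         "weight_precision": "int8",          # default
#         "activation_precision": "int8",      # default
#     })
#     # Expert precisions default to the non-expert values when omitted.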