# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Any, Dict, List, Optional, Tuple

import torch

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.scalar_type import ScalarType, scalar_types
from vllm.logger import init_logger
from vllm.platforms import current_platform

from vllm_mlu import _mlu_ops as mlu_ops

logger = init_logger(__name__)

# Quantization group sizes accepted by the MLU AWQ path.
MLU_SUPPORTED_GROUP_SIZES = [64, 128, 256, 512]


def _default_device_capability() -> int:
    """Return the current platform's capability encoded as major * 10 + minor."""
    major, minor = current_platform.get_device_capability()
    return major * 10 + minor


# Only GPTQ-style and AWQ-style weights in int4 / int8 precision are
# supported (the original note restricts this to MLU 300-series devices and
# newer). NOTE(review): `device_capability` is not currently used to gate the
# result -- confirm whether a capability check is intended.
def query_mlu_supported_quant_types(
        has_zp: bool,
        device_capability: Optional[int] = None) -> List[ScalarType]:
    """Return the weight scalar types the MLU kernels accept.

    Args:
        has_zp: True for AWQ-style checkpoints (unsigned weights with an
            explicit zero point), False for GPTQ-style checkpoints (unsigned
            weights with a fixed symmetric bias).
        device_capability: Device capability as ``major * 10 + minor``;
            queried from ``current_platform`` when omitted.

    Returns:
        The supported ``ScalarType`` values for the requested style.
    """
    if device_capability is None:
        device_capability = _default_device_capability()

    if has_zp:
        # AWQ style: unsigned + zero-point.
        return [scalar_types.uint4, scalar_types.uint8]
    # GPTQ style: unsigned + symmetric bias.
    return [scalar_types.uint4b8, scalar_types.uint8b128]


def check_mlu_supported(
        quant_type: ScalarType,
        group_size: Optional[int],
        has_zp: bool,
        device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
    """Check whether a quantized weight format can run on the MLU kernels.

    Args:
        quant_type: Scalar type of the quantized weights.
        group_size: Quantization group size (None is treated as unsupported).
        has_zp: Whether the checkpoint carries explicit zero points.
        device_capability: Device capability as ``major * 10 + minor``;
            queried from ``current_platform`` when omitted.

    Returns:
        ``(True, None)`` when supported, otherwise ``(False, reason)`` where
        ``reason`` is a human-readable explanation.
    """
    if device_capability is None:
        device_capability = _default_device_capability()

    supported_types = query_mlu_supported_quant_types(
        has_zp, device_capability)

    if quant_type not in supported_types:
        return (False, f"Mlu does not support weight_bits = {quant_type}. "
                f"Only types = {supported_types} "
                f"are supported (for group_size = {group_size}, "
                f"device_capability = {device_capability}, zp = {has_zp}).")
    if group_size is None or group_size not in MLU_SUPPORTED_GROUP_SIZES:
        return (False, f"Mlu does not support group_size = {group_size}. "
                f"Only group_sizes = {MLU_SUPPORTED_GROUP_SIZES} "
                "are supported.")

    # Always return a 2-tuple so callers can unpack ``ok, reason`` uniformly.
    return True, None


# @register_quantization_config("awq_mlu")
class AWQMluConfig(QuantizationConfig):
    """Config class for AWQMlu.

    Reference: https://arxiv.org/abs/2306.00978
    """

    # num_bits -> {has_zero_point -> weight scalar type}
    TYPE_MAP = {
        4: {
            False: scalar_types.uint4b8,
            True: scalar_types.uint4,
        },
        8: {
            False: scalar_types.uint8b128,
            True: scalar_types.uint8,
        }
    }

    # Supported AutoAWQ packing versions.
    VERSION = ["gemm"]

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        lm_head_quantized: bool,
        version: str = "gemm",
    ) -> None:
        """Validate and store the AWQ quantization parameters.

        Raises:
            ValueError: If ``weight_bits`` is not 4 or 8, or ``version`` is
                not listed in ``VERSION``.
        """
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.lm_head_quantized = lm_head_quantized
        # Number of quantized values packed into one int32 word.
        self.pack_factor = 32 // self.weight_bits
        self.version = version
        # When False, zero-point checkpoints are dequantized to a dense
        # weight at load time (see AWQMluLinearMethod) instead of passing a
        # precomputed -zeros * scales term to the quantized matmul.
        self.support_scale_zeros = False

        if self.weight_bits not in [4, 8]:
            raise ValueError(
                "Currently, only 4/8-bit weight quantization is supported for "
                f"AWQMlu, but got {self.weight_bits} bits.")

        if self.version not in self.VERSION:
            raise ValueError(
                f"Currently, only {', '.join(self.VERSION)} version is "
                f"supported for AWQMlu, but got version:{self.version}.")

        if self.version == "gemm":
            # Interleaved order of values within each packed int32 word;
            # reverse_order_map is the inverse permutation of order_map and is
            # used to restore natural column order after unpacking.
            self.order_map = {4: [0, 2, 4, 6, 1, 3, 5, 7], 8: [0, 2, 1, 3]}
            self.reverse_order_map = {4: [0, 4, 1, 5, 2, 6, 3, 7],
                                      8: [0, 2, 1, 3]}
        else:
            # Identity ordering; unreachable while VERSION == ["gemm"].
            self.order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7], 8: [0, 1, 2, 3]}
            self.reverse_order_map = {4: [0, 1, 2, 3, 4, 5, 6, 7],
                                      8: [0, 1, 2, 3]}

    def __repr__(self) -> str:
        return (f"AWQMluConfig(weight_bits={self.weight_bits}, "
                f"group_size={self.group_size}, "
                f"zero_point={self.zero_point}, "
                f"lm_head_quantized={self.lm_head_quantized})")

    @classmethod
    def get_name(cls) -> str:
        return "awq_mlu"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16, torch.float32]

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["quant_config.json", "quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQMluConfig":
        """Build a config from a checkpoint's quantization dict.

        Accepts both AutoAWQ key names (``w_bit``, ``q_group_size``) and the
        HF-style ones (``bits``, ``group_size``).
        """
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                 default=False)
        version = cls.get_from_keys_or(config, ["version"], default="gemm")
        return cls(weight_bits, group_size, zero_point, lm_head_quantized,
                   version)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["AWQMluLinearMethod"]:
        """Return the AWQ linear method for linear layers (and the LM head
        when it is quantized); None for every other layer type."""
        if (isinstance(layer, LinearBase) or
                (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
            return AWQMluLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        """Switch a compatible AWQ checkpoint to the ``awq_mlu`` method.

        Returns ``"awq_mlu"`` when the checkpoint is convertible and the user
        requested no method, ``awq`` or ``awq_mlu``; otherwise None (no
        override).
        """
        can_convert = cls.is_awq_mlu_compatible(hf_quant_cfg)
        # An explicit ``quantization=awq`` is deliberately upgraded too, so
        # there is no path that keeps plain ``awq`` for a convertible model.
        is_valid_user_quant = user_quant in (None, "awq", "awq_mlu")

        if can_convert and is_valid_user_quant:
            name = cls.get_name()
            logger.info(
                "The model is convertible to %s during runtime."
                " Using %s kernel.", name, name)
            return name

        return None

    @classmethod
    def is_awq_mlu_compatible(cls, quant_config: Dict[str, Any]) -> bool:
        """Return whether an HF AWQ ``quant_config`` can run with awq_mlu."""
        # Extract data from quant config. A present-but-None quant_method is
        # treated like a missing one instead of crashing on ``.lower()``.
        quant_method = (quant_config.get("quant_method") or "").lower()
        num_bits = quant_config.get("bits", None)
        group_size = quant_config.get("group_size", None)
        has_zp = quant_config.get("zero_point", None)
        version = quant_config.get("version", "gemm")

        if quant_method != "awq":
            return False

        # If we cannot find the info needed in the config, cannot convert.
        if num_bits is None or group_size is None or has_zp is None:
            return False

        if num_bits not in cls.TYPE_MAP:
            return False

        if version not in cls.VERSION:
            return False

        # check_mlu_supported returns (ok, reason). A non-empty tuple is
        # always truthy, so return only the flag.
        supported, _ = check_mlu_supported(
            quant_type=cls.TYPE_MAP[num_bits][has_zp],
            group_size=group_size,
            has_zp=has_zp)
        return supported


class AWQMluLinearMethod(LinearMethodBase):
    """Linear method for AWQMlu.

    Args:
        quant_config: The AWQMlu quantization config.
    """

    def __init__(self, quant_config: AWQMluConfig):
        self.quant_config = quant_config

    def _use_dense_fallback(self) -> bool:
        """Whether zero-point weights are dequantized to a dense matrix.

        True when the checkpoint has zero points but the scale-zeros path
        (``support_scale_zeros``) is disabled; the layer then runs a plain
        ``mlu_ops.matmul`` on the dequantized weight.
        """
        return bool(self.quant_config.zero_point
                    and not self.quant_config.support_scale_zeros)

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """Register the packed AWQ parameters on ``layer``.

        Registers:
            qweight: int32 ``[in, out // pack_factor]``.
            qzeros:  int32 ``[in // group_size, out // pack_factor]``.
            scales:  ``params_dtype`` ``[in // group_size, out]``.

        Raises:
            ValueError: If the partition sizes are not divisible by the group
                size (input) or the pack factor (output).
        """
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        weight_loader = extra_weight_attrs.get("weight_loader")
        pack_factor = self.quant_config.pack_factor
        num_groups = input_size_per_partition // self.quant_config.group_size

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader)

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition // pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=pack_factor,
            weight_loader=weight_loader)

        scales = GroupQuantScaleParameter(data=torch.empty(
            num_groups,
            output_size_per_partition,
            dtype=params_dtype,
        ),
                                          input_dim=0,
                                          output_dim=1,
                                          weight_loader=weight_loader)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded AutoAWQ tensors into the layout used by apply().

        On the dense fallback path ``qweight`` becomes the dequantized weight
        and ``qzeros``/``scales`` are cleared. Otherwise ``qweight`` becomes
        the repacked int8 weight, ``qzeros`` holds the precomputed
        ``-zeros * scales`` term (or None) and ``scales`` is transposed.
        """
        packed_qweight, scale_zeros = self.extract_autoawq(layer)
        layer.qweight = torch.nn.Parameter(packed_qweight.contiguous(),
                                           requires_grad=False)
        if self._use_dense_fallback():
            layer.qzeros = None
            layer.scales = None
            return

        if scale_zeros is not None:
            layer.qzeros = torch.nn.Parameter(scale_zeros.contiguous(),
                                              requires_grad=False)
        else:
            layer.qzeros = None
        # [in // group_size, out] -> [out, in // group_size]
        layer.scales = torch.nn.Parameter(
            layer.scales.data.transpose(0, 1).contiguous(),
            requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None,
              residual: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the layer's matmul, adding ``bias`` and ``residual`` if given."""
        if self._use_dense_fallback():
            # Weight was dequantized at load time; use a plain matmul.
            output = mlu_ops.matmul(x, layer.qweight, bias)
            if residual is not None:
                output = output + residual
        else:
            # NOTE(review): "none" is passed positionally to the MLU op;
            # presumably the activation mode -- confirm against _mlu_ops.
            output = mlu_ops.weight_only_quant_matmul(
                x, layer.qweight, layer.scales, layer.qzeros, bias, residual,
                "none", self.quant_config.weight_bits)
        return output

    def extract_autoawq(self, layer: torch.nn.Module):
        """Convert AutoAWQ-packed tensors into the layout the MLU ops expect.

        Returns:
            ``(weight, scale_zeros)``. On the dense fallback path ``weight``
            is the dequantized ``[out, in]`` matrix and ``scale_zeros`` is
            None. Otherwise ``weight`` is an ``[out, in]`` int8 tensor (two
            4-bit values per byte when ``weight_bits == 4``) and
            ``scale_zeros`` is ``-zeros * scales`` transposed to
            ``[out, in // group_size]``, or None.
        """
        qweight = layer.qweight.data
        qzeros = layer.qzeros.data
        scales = layer.scales.data
        bits = self.quant_config.weight_bits
        group_size = self.quant_config.group_size
        mask = (2**bits) - 1

        # Unpack every int32 word into pack_factor integer columns.
        iweight, izeros = self.unpack_awq_int32_into_int8(qweight, qzeros, bits)
        # Undo AutoAWQ's interleaved packing order.
        iweight, izeros = self.reverse_awq_order(iweight, izeros, bits)

        # Keep only the low `bits` bits: the narrowing cast during unpacking
        # can leave bits of the neighbouring packed value in the upper bits.
        iweight = torch.bitwise_and(iweight, mask)
        if izeros is not None:
            izeros = torch.bitwise_and(izeros, mask)

        if self._use_dense_fallback():
            # Expand per-group scales/zeros to one row per input channel.
            scales = scales.repeat_interleave(group_size, dim=0)
            if izeros is not None:
                izeros = izeros.repeat_interleave(group_size, dim=0)
                fweight = (iweight - izeros) * scales
            else:
                fweight = iweight * scales
            # transpose [ci, co] -> [co, ci]
            return fweight.transpose(0, 1), None

        if (self.quant_config.zero_point
                and self.quant_config.support_scale_zeros
                and izeros is not None):
            scale_zeros = izeros.to(scales.dtype) * -1 * scales
            # transpose [ci, co] -> [co, ci]
            scale_zeros = scale_zeros.transpose(0, 1)
        else:
            scale_zeros = None

        # transpose [ci, co] -> [co, ci]; for 8-bit values >= 128 wrap to
        # negative int8 (the unpacking step already re-centred them).
        iweight = iweight.to(torch.int8).transpose(0, 1)
        if bits == 4:
            # Pack column pairs into one byte: odd columns go to the high
            # nibble, even columns to the low nibble.
            packed_qweight = self.combine_low_bits(iweight[:, 1::2],
                                                   iweight[:, 0::2])
        else:
            packed_qweight = iweight
        return packed_qweight, scale_zeros

    def unpack_awq_int32_into_int8(self, qweight: torch.Tensor,
                                   qzeros: Optional[torch.Tensor], bits: int):
        """Split each int32 word into ``32 // bits`` integer columns.

        Args:
            qweight: Packed int32 weights; each row expands to
                ``shape[1] * (32 // bits)`` columns.
            qzeros: Packed int32 zeros, or None.
            bits: Bits per quantized value (4 or 8).

        Returns:
            ``(iweights, izeros)`` as int8 (4-bit) or int16 (8-bit) tensors;
            ``izeros`` is None when ``qzeros`` is None. Unless the dense
            fallback is active, weights are shifted by ``-2**(bits-1)``
            modulo ``2**bits``; zeros are shifted likewise when the config has
            no zero points.
        """
        shifts = torch.arange(0, 32, bits, device=qweight.device)
        # int16 for 8-bit so the unsigned 0..255 range survives arithmetic.
        dtype = torch.int16 if bits == 8 else torch.int8
        offset = 2**(bits - 1)
        mask = (2**bits) - 1

        # unpacking columnwise
        iweights = torch.bitwise_right_shift(
            qweight[:, :, None], shifts[None, None, :]).to(dtype)
        iweights = iweights.view(iweights.shape[0], -1)
        if not self._use_dense_fallback():
            iweights = torch.bitwise_and(iweights - offset, mask)

        if qzeros is None:
            return iweights, None

        # unpacking columnwise
        izeros = torch.bitwise_right_shift(
            qzeros[:, :, None], shifts[None, None, :]).to(dtype)
        izeros = izeros.view(izeros.shape[0], -1)
        if not self.quant_config.zero_point:
            izeros = torch.bitwise_and(izeros - offset, mask)
        return iweights, izeros

    def reverse_awq_order(self, iweights: torch.Tensor,
                          izeros: Optional[torch.Tensor], bits: int):
        """Permute unpacked columns back to their natural order.

        Within every group of ``32 // bits`` columns (one packed int32 word),
        columns are reordered by ``quant_config.reverse_order_map[bits]``.

        Returns:
            ``(rweights, rzeros)``; ``rzeros`` is None when ``izeros`` is None.
        """
        reverse_order_tensor = torch.arange(iweights.shape[-1],
                                            dtype=torch.int32,
                                            device=iweights.device)
        reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
        reverse_order_tensor = reverse_order_tensor[
            :, self.quant_config.reverse_order_map[bits]]
        reverse_order_tensor = reverse_order_tensor.view(-1)

        rweights = iweights[:, reverse_order_tensor]
        rzeros = izeros[:, reverse_order_tensor] if izeros is not None else None
        return rweights, rzeros

    def combine_low_bits(self, tensor_a, tensor_b):
        """
        Combine the lower 4 bits of two int8 tensors into a new int8 tensor.

        Args:
            tensor_a (torch.Tensor): First tensor of type int8; its low
                nibble becomes the high nibble of the result.
            tensor_b (torch.Tensor): Second tensor of type int8; its low
                nibble becomes the low nibble of the result.

        Returns:
            torch.Tensor: New tensor of type int8, combining lower 4 bits of
                tensor_a and tensor_b.

        Raises:
            ValueError: If either input is not int8.
        """
        # Ensure both inputs are int8.
        if tensor_a.dtype != torch.int8 or tensor_b.dtype != torch.int8:
            raise ValueError("Both tensors must be of int8 type.")

        # Extract the low 4 bits of each tensor.
        low_bits_a = torch.bitwise_and(tensor_a, 0x0F)
        low_bits_b = torch.bitwise_and(tensor_b, 0x0F)

        # Move tensor_a's nibble into the high half of the byte (int8
        # wrap-around is intended: only the bit pattern matters).
        shifted_low_bits_a = low_bits_a << 4

        # Combine the two nibbles into one byte.
        combined = torch.bitwise_or(shifted_low_bits_a, low_bits_b)

        return combined