# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
                                               set_weight_attrs)
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm_mlu import _mlu_ops as mlu_ops
from vllm.logger import init_logger

logger = init_logger(__name__)


# @register_quantization_config("weightonly")
class WeightOnlyConfig(QuantizationConfig):
    """Config class for WeightOnly.

    Holds the weight bit-width and the quantization mode read from the
    checkpoint's ``quantize_config.json``.
    """

    def __init__(
        self,
        weight_bits: int,
        quant_mode: str,  # weight_only
    ) -> None:
        """Store the quantization settings and derive the packing factor.

        Args:
            weight_bits: Bit-width of the quantized weights.
            quant_mode: Quantization mode name; ``"WeightOnly"`` is the mode
                whose bit-width is validated here.

        Raises:
            ValueError: If ``quant_mode`` is ``"WeightOnly"`` and
                ``weight_bits`` is neither 8 nor 4.
        """
        super().__init__()
        self.weight_bits = weight_bits
        self.quant_mode = quant_mode

        if quant_mode == "WeightOnly" and self.weight_bits not in (4, 8):
            raise ValueError(
                "Currently, only 8/4-bit weight quantization is supported for "
                f"weight_only, but got {self.weight_bits} bits.")
        # Divisor applied to the input dimension of the int8 ``qweight``
        # tensor allocated in ``WeightOnlyLinearMethod.create_weights``.
        self.pack_factor = 8 // self.weight_bits

    def __repr__(self) -> str:
        return (f"WeightOnlyConfig(weight_bits={self.weight_bits}, "
                f"quant_mode={self.quant_mode})")

    def get_name(self) -> str:
        """Return the name of this quantization method."""
        return "WeightOnly"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the activation dtypes this config declares as supported."""
        return [torch.half, torch.bfloat16]

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the file names searched for the quantization config."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "WeightOnlyConfig":
        """Build a config from a parsed quantization config dict.

        Args:
            config: Mapping that must contain ``"bits"`` and may contain
                ``"quant_mode"``.

        Returns:
            A ``WeightOnlyConfig``; ``quant_mode`` defaults to
            ``"WeightOnly"`` when the key is absent.
        """
        weight_bits = cls.get_from_keys(config, ["bits"])
        try:
            quant_mode = cls.get_from_keys(config, ["quant_mode"])
        except ValueError:
            # Only a missing key should trigger the default; any other error
            # is a real problem and must propagate.
            quant_mode = "WeightOnly"
        return cls(weight_bits, quant_mode)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["WeightOnlyLinearMethod"]:
        """Return the linear method for ``LinearBase`` layers, else ``None``."""
        if isinstance(layer, LinearBase):
            return WeightOnlyLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the activation names listed as scaled activations."""
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]


class WeightOnlyLinearMethod(LinearMethodBase):
    """Linear method for WeightOnly.

    Args:
        quant_config: The WeightOnly quantization config.
    """

    def __init__(self, quant_config: WeightOnlyConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Allocate and register ``qweight`` and ``scales`` on ``layer``.

        Args:
            layer: Module that receives the two parameters.
            input_size_per_partition: Input features held by this partition.
            output_partition_sizes: Output sizes of the logical sub-matrices
                in this partition; their sum is the partition's output size.
            input_size: Full (unpartitioned) input size; not used here.
            output_size: Full (unpartitioned) output size.
            params_dtype: Dtype used to allocate ``scales``.
            **extra_weight_attrs: Extra attributes attached to both
                parameters after registration.

        Raises:
            ValueError: If the config's ``quant_mode`` is not
                ``"WeightOnly"``; no parameters can be built for other modes.
        """
        quant_mode = self.quant_config.quant_mode
        if quant_mode != "WeightOnly":
            # Fail early with a clear message instead of leaving the layer
            # without ``qweight``/``scales``.
            raise ValueError(
                "WeightOnlyLinearMethod only supports quant_mode="
                f"'WeightOnly', but got {quant_mode!r}.")

        output_size_per_partition = sum(output_partition_sizes)
        # ``input_dim`` of the scales is 0 only when this partition does not
        # hold the full output dimension; otherwise it is left as None.
        scale_and_zero_input_dim = (
            0 if output_size != output_size_per_partition else None)

        qweight = Parameter(
            torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.pack_factor,
                device="mlu",
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(qweight, {
            "input_dim": 1,
            "output_dim": 0,
        })

        scales = Parameter(
            torch.empty(
                output_size_per_partition,
                device="mlu",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": scale_and_zero_input_dim,
            "output_dim": 0,
        })

        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert ``layer.scales`` to float32 if it has another dtype."""
        if layer.scales.dtype != torch.float:
            layer.scales = Parameter(layer.scales.to(torch.float),
                                     requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None,
              residual: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the weight-only quantized matmul for ``layer`` on ``x``.

        Inputs with more than two dimensions are flattened to 2-D for the
        kernel call and the leading dimensions are restored on the output.

        Args:
            layer: Module holding ``qweight`` and ``scales``.
            x: Input activations; the last dimension is the feature dimension.
            bias: Optional bias forwarded to the kernel.
            residual: Optional residual forwarded to the kernel.

        Returns:
            The kernel output, with the leading dimensions of ``x`` restored
            when ``x`` had more than two dimensions.
        """
        x_shape = x.shape
        needs_flatten = len(x_shape) > 2
        if needs_flatten:
            # ``reshape`` returns a view when possible and copies only when
            # the layout cannot be viewed, where ``view`` would raise.
            x = x.reshape(-1, x_shape[-1])
        # NOTE(review): the positional ``None`` and ``"none"`` arguments are
        # passed through unchanged; presumably zero-points and activation
        # mode — confirm against the mlu_ops signature.
        out = mlu_ops.weight_only_quant_matmul(x, layer.qweight, layer.scales,
                                               None, bias, residual, "none",
                                               self.quant_config.weight_bits)
        if needs_flatten:
            out = out.view(*x_shape[:-1], out.shape[-1])
        return out