151 lines
5.3 KiB
Python
151 lines
5.3 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||
|
|
|
||
|
|
from typing import Any, Dict, List, Optional
|
||
|
|
|
||
|
|
import torch
|
||
|
|
from torch.nn.parameter import Parameter
|
||
|
|
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
|
||
|
|
set_weight_attrs)
|
||
|
|
from vllm.model_executor.layers.quantization import register_quantization_config
|
||
|
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||
|
|
|
||
|
|
from vllm_mlu import _mlu_ops as mlu_ops
|
||
|
|
|
||
|
|
from vllm.logger import init_logger
|
||
|
|
|
||
|
|
# Module-level logger for this quantization backend; it is not referenced by
# any of the code in this module as shown.
logger = init_logger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
# @register_quantization_config("weightonly")
|
||
|
|
class WeightOnlyConfig(QuantizationConfig):
    """Config class for WeightOnly.

    Weights are stored as int8 tensors (two 4-bit values packed per element
    when ``weight_bits == 4``) together with one scale per output channel;
    activations stay in half or bfloat16.

    Args:
        weight_bits: Bit width of the quantized weights. When
            ``quant_mode == "WeightOnly"`` only 8 and 4 are accepted.
        quant_mode: Quantization mode string from the checkpoint's
            ``quantize_config.json``; ``"WeightOnly"`` is the mode the
            linear method in this module builds weights for.
    """

    def __init__(
        self,
        weight_bits: int,
        quant_mode: str,  # expected value: "WeightOnly"
    ) -> None:
        super().__init__()
        self.weight_bits = weight_bits
        self.quant_mode = quant_mode

        if quant_mode == "WeightOnly" and self.weight_bits not in (8, 4):
            raise ValueError(
                "Currently, only 8/4-bit weight quantization is supported for "
                f"weight_only, but got {self.weight_bits} bits.")
        # Number of quantized values stored in one int8 element:
        # 1 for 8-bit weights, 2 for 4-bit weights.
        self.pack_factor = 8 // self.weight_bits

    def __repr__(self) -> str:
        return (f"WeightOnlyConfig(weight_bits={self.weight_bits}, "
                f"quant_mode={self.quant_mode})")

    def get_name(self) -> str:
        """Return the quantization method name."""
        return "WeightOnly"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the activation dtypes this method accepts."""
        return [torch.half, torch.bfloat16]

    @staticmethod
    def get_config_filenames() -> List[str]:
        """Return the file names searched for the quantization config."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "WeightOnlyConfig":
        """Build a config from the parsed ``quantize_config.json`` dict.

        ``bits`` is required. ``quant_mode`` is optional and defaults to
        ``"WeightOnly"``.

        Raises:
            ValueError: If ``bits`` is missing, or if the bit width is
                unsupported for the resolved ``quant_mode``.
        """
        weight_bits = cls.get_from_keys(config, ["bits"])
        try:
            quant_mode = cls.get_from_keys(config, ["quant_mode"])
        except ValueError:
            # vLLM's ``get_from_keys`` raises ValueError when the key is
            # absent. Catch only that, so unrelated errors are not hidden.
            quant_mode = "WeightOnly"
        return cls(weight_bits, quant_mode)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["WeightOnlyLinearMethod"]:
        """Return the quantized linear method for linear layers, else None.

        Args:
            layer: The module being constructed.
            prefix: The module's name prefix (unused here).
        """
        if isinstance(layer, LinearBase):
            return WeightOnlyLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the names of activations this method treats as scaled."""
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||
|
|
|
||
|
|
|
||
|
|
class WeightOnlyLinearMethod(LinearMethodBase):
    """Linear method for WeightOnly.

    Each linear layer holds an int8 ``qweight`` of shape
    ``(out_features, in_features // pack_factor)`` and a 1-D ``scales``
    vector with one entry per output feature. The forward pass is dispatched
    to ``mlu_ops.weight_only_quant_matmul``.

    Args:
        quant_config: The WeightOnly quantization config.
    """

    def __init__(self, quant_config: WeightOnlyConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register the ``qweight`` and ``scales`` parameters on ``layer``.

        Args:
            layer: The linear layer receiving the parameters.
            input_size_per_partition: Input size of this partition.
            output_partition_sizes: Output sizes of this partition's
                sub-projections; their sum is the local output size.
            input_size: Full input size (unused here).
            output_size: Full output size; compared with the local output
                size to decide how ``scales`` is tagged.
            params_dtype: Dtype of ``scales`` as loaded, before
                ``process_weights_after_loading`` upcasts it to float32.
            **extra_weight_attrs: Extra attributes attached to both
                parameters (presumably including the weight loader).

        Raises:
            ValueError: If ``quant_config.quant_mode`` is not "WeightOnly".
        """
        quant_mode = self.quant_config.quant_mode
        if quant_mode != "WeightOnly":
            # Fail fast. Without these parameters, weight loading and
            # ``apply`` cannot work, and the failure would otherwise surface
            # later as an unrelated AttributeError.
            raise ValueError(
                "WeightOnlyLinearMethod only supports quant_mode "
                f"'WeightOnly', but got '{quant_mode}'.")

        output_size_per_partition = sum(output_partition_sizes)

        # NOTE(review): ``scales`` is 1-D, yet it is tagged input_dim=0
        # whenever the output dimension is partitioned. Confirm which loader
        # relies on this before changing it.
        scale_and_zero_input_dim = None
        if output_size != output_size_per_partition:
            scale_and_zero_input_dim = 0

        # Assumes input_size_per_partition is divisible by pack_factor
        # (relevant for 4-bit); otherwise the packed width is truncated.
        qweight = Parameter(
            torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.quant_config.pack_factor,
                device="mlu",
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(qweight, {
            "input_dim": 1,
            "output_dim": 0,
        })

        scales = Parameter(
            torch.empty(
                output_size_per_partition,
                device="mlu",
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": scale_and_zero_input_dim,
            "output_dim": 0,
        })

        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Upcast ``layer.scales`` to float32; no-op if already float32."""
        if layer.scales.dtype != torch.float:
            layer.scales = Parameter(layer.scales.to(torch.float),
                                     requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None,
              residual: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the weight-only quantized matmul for ``x``.

        Inputs with more than two dims are flattened to 2-D before the kernel
        call, and the result is reshaped back to ``(*x.shape[:-1], N)``, where
        ``N`` is the kernel output's last dim.

        Args:
            layer: Layer holding ``qweight`` and ``scales``.
            x: Input activations; the last dim is the input feature dim.
            bias: Optional bias forwarded to the kernel.
            residual: Optional residual forwarded to the kernel. A residual
                with more than two dims is flattened together with ``x``.

        Returns:
            The kernel output with ``x``'s leading dims restored.
        """
        x_shape = x.shape
        flatten = len(x_shape) > 2
        if flatten:
            # reshape (not view) so that non-contiguous inputs also work; it
            # is identical to view whenever view would succeed.
            x = x.reshape(-1, x_shape[-1])
            if residual is not None and residual.dim() > 2:
                # Keep the residual's leading dims consistent with the 2-D x.
                residual = residual.reshape(-1, residual.shape[-1])
        out = mlu_ops.weight_only_quant_matmul(
            x,
            layer.qweight,
            layer.scales,
            None,  # presumably the zero-point (unused here) - confirm
            bias,
            residual,
            "none",  # presumably the fused activation mode - confirm
            self.quant_config.weight_bits)
        if flatten:
            out = out.reshape(*x_shape[:-1], out.shape[-1])
        return out
|