338 lines
16 KiB
Python
338 lines
16 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||
|
|
|
||
|
|
from typing import Any, Dict, List, Optional
|
||
|
|
|
||
|
|
import torch
|
||
|
|
from torch.nn.parameter import Parameter
|
||
|
|
|
||
|
|
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
|
||
|
|
set_weight_attrs)
|
||
|
|
from vllm.model_executor.layers.quantization import register_quantization_config
|
||
|
|
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||
|
|
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||
|
|
GroupQuantScaleParameter,
|
||
|
|
ModelWeightParameter,
|
||
|
|
RowvLLMParameter)
|
||
|
|
from vllm_mlu import _mlu_ops as mlu_ops
|
||
|
|
from vllm_mlu.model_executor.layers.quantization.utils.common_utils import (str_dtype_to_torch,
|
||
|
|
str_dtype_to_bits,
|
||
|
|
is_fp8_str_dtype)
|
||
|
|
|
||
|
|
|
||
|
|
# @register_quantization_config("smoothquant")
|
||
|
|
class SmoothQuantConfig(QuantizationConfig):
    """Config class for SmoothQuant.

    Carries quantization settings for both regular linear layers and
    expert (MoE) linear layers: weight/activation precisions, the input
    quantization method (per_token / per_tensor), and optional group
    quantization controlled by ``group_size``.
    """

    def __init__(
        self,
        quant_mode: str,  # smoothquant
        input_quant_method: str,  # per token/per tensor
        group_size: int,
        weight_precision: str,
        activation_precision: str,
        only_expert_per_group: bool,
        expert_weight_precision: str,
        expert_activation_precision: str,
        force_use_weightonly_except_expert: bool,
    ) -> None:
        super().__init__()
        self.quant_mode = quant_mode
        self.input_quant_method = input_quant_method
        self.group_size = group_size
        self.weight_precision = weight_precision
        self.activation_precision = activation_precision
        self.only_expert_per_group = only_expert_per_group
        self.expert_weight_precision = expert_weight_precision
        self.expert_activation_precision = expert_activation_precision
        self.force_use_weightonly_except_expert = force_use_weightonly_except_expert

        if quant_mode == "SmoothQuant" and self.input_quant_method not in ("per_token", "per_tensor"):
            raise ValueError(
                "Currently, only per_token or per_tensor input quantization is supported for "
                f"SmoothQuant, but got {self.input_quant_method}.")

        self.weight_bits = str_dtype_to_bits(self.weight_precision)
        self.expert_weight_bits = str_dtype_to_bits(self.expert_weight_precision)
        # int4 weights have no native torch dtype; they are stored packed
        # inside int8 tensors.
        if self.weight_precision == 'int4':
            self.weight_dtype = torch.int8
        else:
            self.weight_dtype = str_dtype_to_torch(self.weight_precision)
        if self.expert_weight_precision == 'int4':
            self.expert_weight_dtype = torch.int8
        else:
            self.expert_weight_dtype = str_dtype_to_torch(self.expert_weight_precision)
        self.is_fp8 = is_fp8_str_dtype(self.weight_precision)
        self.expert_is_fp8 = is_fp8_str_dtype(self.expert_weight_precision)

        # Number of logical weights packed into one byte of storage
        # (2 for int4, 1 for int8/fp8).
        self.pack_factor = 8 // self.weight_bits
        self.expert_pack_factor = 8 // self.expert_weight_bits

    def __repr__(self) -> str:
        return (f"SmoothQuantConfig(input_quant_method={self.input_quant_method}, "
                f"quant_mode={self.quant_mode}, "
                f"group_size={self.group_size}, "
                f"weight_precision={self.weight_precision}, "
                f"activation_precision={self.activation_precision}, "
                f"only_expert_per_group={self.only_expert_per_group}, "
                f"expert_weight_precision={self.expert_weight_precision}, "
                f"expert_activation_precision={self.expert_activation_precision}, "
                f"force_use_weightonly_except_expert={self.force_use_weightonly_except_expert})")

    @classmethod
    def get_name(cls) -> str:
        # NOTE: the first argument of a classmethod is the class, so it is
        # named `cls` (the original named it `self`, which was misleading).
        return "SmoothQuant"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
        """Build a SmoothQuantConfig from a parsed quantize_config.json dict."""
        quant_mode = cls.get_from_keys(config, ["quant_mode"])
        input_quant_method = cls.get_from_keys(config, ["input_quant_method"])
        group_size = cls.get_from_keys_or(config, ["group_size"], 1)
        weight_precision = cls.get_from_keys_or(config, ["weight_precision"], "int8")
        activation_precision = cls.get_from_keys_or(config, ["activation_precision"], "int8")
        only_expert_per_group = cls.get_from_keys_or(config, ["only_expert_per_group"], False)
        expert_weight_precision = cls.get_from_keys_or(config, ["expert_weight_precision"], None)
        expert_activation_precision = cls.get_from_keys_or(config, ["expert_activation_precision"], None)
        force_use_weightonly_except_expert = cls.get_from_keys_or(config, ["force_use_weightonly_except_expert"], False)

        # Expert settings default to the non-expert ones when unspecified.
        if expert_weight_precision is None:
            expert_weight_precision = weight_precision
        # When group quantization is restricted to experts, non-expert int4
        # weights fall back to int8; the expert precision captured above
        # keeps the original int4 setting.
        if group_size > 1 and only_expert_per_group and weight_precision == 'int4':
            weight_precision = 'int8'

        if expert_activation_precision is None:
            expert_activation_precision = activation_precision

        return cls(quant_mode=quant_mode,
                   input_quant_method=input_quant_method,
                   group_size=group_size,
                   weight_precision=weight_precision,
                   activation_precision=activation_precision,
                   only_expert_per_group=only_expert_per_group,
                   expert_weight_precision=expert_weight_precision,
                   expert_activation_precision=expert_activation_precision,
                   force_use_weightonly_except_expert=force_use_weightonly_except_expert)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["SmoothQuantLinearMethod"]:
        # Only plain linear layers are quantized here; everything else is
        # left untouched.
        if isinstance(layer, LinearBase):
            return SmoothQuantLinearMethod(self, prefix)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
|
||
|
|
|
||
|
|
|
||
|
|
class SmoothQuantLinearMethod(LinearMethodBase):
    """Linear method for SmoothQuant.

    Args:
        quant_config: The SmoothQuant quantization config.
        prefix: Layer name prefix; used to detect expert (MoE) layers.
    """

    def __init__(self, quant_config: SmoothQuantConfig, prefix: str):
        self.quant_config = quant_config
        # for per-tensor case, we can skip quant input for the first attn|ffn linear
        # and fusion this step in layernorm to get better performance
        self.skip_quant_input = False
        self.compute_dtype = torch.get_default_dtype()
        # Routed experts use the expert-specific precision settings; shared
        # experts are treated like regular layers.
        self.is_expert = 'expert' in prefix and "shared_expert" not in prefix
        self.weight_dtype = quant_config.expert_weight_dtype if self.is_expert else quant_config.weight_dtype
        self.pack_factor = quant_config.expert_pack_factor if self.is_expert else quant_config.pack_factor
        self.is_fp8 = quant_config.expert_is_fp8 if self.is_expert else quant_config.is_fp8

        # Group quantization applies when group_size > 1; if
        # only_expert_per_group is set, it is restricted to expert layers.
        if quant_config.only_expert_per_group and self.is_expert and quant_config.group_size > 1:
            self.is_group_quant = True
        elif quant_config.only_expert_per_group is False and quant_config.group_size > 1:
            self.is_group_quant = True
        else:
            self.is_group_quant = False
        # A smooth (per-input-channel) scale is only loaded for per_token
        # quantization, and only when weight-only mode is not forced for
        # non-expert layers.
        self.has_smooth = self.quant_config.input_quant_method == "per_token" and (
            self.quant_config.force_use_weightonly_except_expert is False or self.is_expert)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the quantized weight and scale parameters.

        Registers ``qweight`` and ``per_channel_scale`` on ``layer``, plus
        ``smooth`` (per_token) and/or ``scale_to_int`` (per_tensor) when the
        config requires them.
        """
        output_size_per_partition = sum(output_partition_sizes)
        # NOTE(review): this checks quant_config.pack_factor while the weight
        # below is packed with self.pack_factor (which may differ for expert
        # layers) — confirm whether the expert pack factor should be checked.
        if (output_size_per_partition % self.quant_config.pack_factor != 0):
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        weight_loader = extra_weight_attrs.get("weight_loader")
        group_num = 1
        if self.is_group_quant:
            if input_size_per_partition % self.quant_config.group_size != 0:
                raise ValueError(
                    f"The input size {input_size_per_partition} is not aligned with the quantized "
                    f"weight shape. This can be caused by too large "
                    f"tensor parallel size. group_size: {self.quant_config.group_size}.")

            # Group count for this partition (ceiling division); when the
            # layer is split by tensor parallelism, size it from the
            # partition rather than the full input dimension.
            group_num = (input_size + self.quant_config.group_size - 1) // self.quant_config.group_size
            if input_size_per_partition != input_size:
                group_num = (input_size_per_partition + self.quant_config.group_size - 1) // self.quant_config.group_size

        # Quantized weight; int4 values are packed pack_factor-per-byte
        # along the input dimension.
        qweight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // self.pack_factor,
                device="mlu",
                dtype=self.weight_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        if self.is_group_quant:
            # One scale per (output channel, input group).
            per_channel_scale = GroupQuantScaleParameter(
                data=torch.empty(
                    output_size_per_partition,
                    group_num,
                    device="mlu",
                    dtype=torch.float32,
                ),
                input_dim=1,
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            # One scale per output channel.
            per_channel_scale = ChannelQuantScaleParameter(
                data=torch.empty(
                    output_size_per_partition,
                    device="mlu",
                    dtype=torch.float32,
                ),
                output_dim=0,
                weight_loader=weight_loader,
            )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("per_channel_scale", per_channel_scale)

        if self.has_smooth:
            # Per-input-channel smoothing scale applied before quantization.
            smooth = RowvLLMParameter(
                data=torch.empty(
                    input_size_per_partition,
                    device="mlu",
                    dtype=torch.float32,
                ),
                input_dim=0,
                weight_loader=weight_loader,
            )
            set_weight_attrs(smooth, {
                "ignore_warning": True,
            })
            layer.register_parameter("smooth", smooth)
        if self.quant_config.input_quant_method == "per_tensor":
            # Static input scale used for per-tensor quantization.
            scale_to_int = RowvLLMParameter(
                data=torch.empty(
                    input_size_per_partition,
                    device="mlu",
                    dtype=torch.float32,
                ),
                input_dim=0,
                weight_loader=weight_loader,
            )
            set_weight_attrs(scale_to_int, {
                "ignore_warning": True,
            })
            layer.register_parameter("scale_to_int", scale_to_int)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Normalize scale dtypes to float32 and freeze loaded parameters."""
        if self.has_smooth and layer.smooth.dtype != torch.float:
            layer.smooth = layer.smooth.to(torch.float)
        if self.quant_config.input_quant_method == "per_tensor" and layer.scale_to_int.dtype != torch.float:
            layer.scale_to_int = layer.scale_to_int.to(torch.float)
        if layer.per_channel_scale.dtype != torch.float:
            layer.per_channel_scale = layer.per_channel_scale.to(torch.float)

        # Re-wrap as plain Parameters so no gradient state is kept.
        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
        layer.per_channel_scale = Parameter(layer.per_channel_scale.data, requires_grad=False)
        if self.has_smooth:
            layer.smooth = Parameter(layer.smooth.data, requires_grad=False)
        if self.quant_config.input_quant_method == "per_tensor":
            layer.scale_to_int = Parameter(layer.scale_to_int.data, requires_grad=False)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None,
              residual: Optional[torch.Tensor] = None,
              input_scale: Optional[torch.Tensor] = None,
              use_tp_weight : bool = False,
              output: Optional[torch.Tensor] = None,
              ) -> torch.Tensor:
        """Quantize the activation (unless pre-quantized) and run the
        quantized matmul, optionally adding bias/residual into ``output``."""
        layer_smooth = layer.smooth if self.has_smooth else None
        layer_qweight = layer.qweight
        layer_per_channel_scale = layer.per_channel_scale
        # Substitute tensor-parallel weight variants when present.
        if use_tp_weight:
            if hasattr(layer, 'tp_smooth'):
                layer_smooth = layer.tp_smooth
            if hasattr(layer, 'tp_qweight'):
                layer_qweight = layer.tp_qweight
            if hasattr(layer, 'tp_per_channel_scale'):
                layer_per_channel_scale = layer.tp_per_channel_scale

        quant_input = None
        if self.skip_quant_input:
            # Input was already quantized upstream (e.g. fused into layernorm).
            quant_input = x
        elif self.quant_config.input_quant_method == "per_token":
            if self.is_fp8:
                quant_input, input_scale = mlu_ops.scaled_quantize(x,
                                                                   layer_smooth,
                                                                   quant_type=self.weight_dtype,
                                                                   quant_mode='dynamic_per_token')
            else:
                quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer_smooth, None)
        elif self.quant_config.input_quant_method == "per_tensor":
            quant_input = mlu_ops.quantize(x, layer.scale_to_int, None)
        else:
            # BUGFIX: the original read self.input_quant_method, which does
            # not exist on this class (it lives on quant_config), so the
            # error path itself raised AttributeError.
            raise ValueError(
                "Currently, only per_token or per_tensor input quantization is supported for "
                f"SmoothQuant, but got {self.quant_config.input_quant_method}.")

        # Flatten leading dims so the matmul sees a 2-D activation.
        quant_input_shape = quant_input.shape
        if len(quant_input_shape) > 2:
            quant_input = quant_input.view(-1, quant_input_shape[-1])
            input_scale = input_scale.view(-1)
            if residual is not None and len(residual.shape) > 2:
                residual = residual.view(-1, residual.shape[-1])
        if self.is_fp8:
            out = mlu_ops.scaled_matmul(quant_input, layer_qweight, input_scale,
                                        layer_per_channel_scale,
                                        self.compute_dtype if hasattr(self, 'compute_dtype') else x.dtype,
                                        bias,
                                        c=residual, act_mode="none", quant_bit_size=8,
                                        alpha=1.0, beta=1.0, use_hp_active=False,
                                        a_quant_bit_size=8, a_calib=None, b_calib=None)
            # scaled_matmul has no output= argument; copy into the caller's
            # buffer when one is provided.
            if output is not None:
                out = out.view(output.shape)
                output.copy_(out)
                out = output
        else:
            if output is not None:
                out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
                                                  layer_per_channel_scale, self.compute_dtype, bias, residual, output=output)
            else:
                out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
                                                  layer_per_channel_scale, self.compute_dtype, bias, residual)
        # Restore the caller's leading dimensions.
        if len(quant_input_shape) > 2:
            out = out.view(*quant_input_shape[:-1], out.shape[-1])
        return out
|