# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import (LinearMethodBase, LinearBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import register_quantization_config
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
ModelWeightParameter,
RowvLLMParameter)
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.model_executor.layers.quantization.utils.common_utils import (str_dtype_to_torch,
str_dtype_to_bits,
is_fp8_str_dtype)
# @register_quantization_config("smoothquant")
class SmoothQuantConfig(QuantizationConfig):
"""Config class for SmoothQuant.
"""
def __init__(
self,
quant_mode: str, # smoothquant
input_quant_method: str, # per token/per tensor
group_size: int,
weight_precision: str,
activation_precision: str,
only_expert_per_group: bool,
expert_weight_precision: str,
expert_activation_precision: str,
force_use_weightonly_except_expert: bool,
) -> None:
super().__init__()
self.quant_mode = quant_mode
self.input_quant_method = input_quant_method
self.group_size = group_size
self.weight_precision = weight_precision
self.activation_precision = activation_precision
self.only_expert_per_group = only_expert_per_group
self.expert_weight_precision = expert_weight_precision
self.expert_activation_precision = expert_activation_precision
self.force_use_weightonly_except_expert = force_use_weightonly_except_expert
if quant_mode == "SmoothQuant" and (self.input_quant_method != "per_token" and self.input_quant_method != "per_tensor"):
raise ValueError(
"Currently, only per_token or per_tensor input quantization is supported for "
f"SmoothQuant, but got {self.input_quant_method}.")
self.weight_bits = str_dtype_to_bits(self.weight_precision)
self.expert_weight_bits = str_dtype_to_bits(self.expert_weight_precision)
if self.weight_precision == 'int4':
self.weight_dtype = torch.int8
else:
self.weight_dtype = str_dtype_to_torch(self.weight_precision)
if self.expert_weight_precision == 'int4':
self.expert_weight_dtype = torch.int8
else:
self.expert_weight_dtype = str_dtype_to_torch(self.expert_weight_precision)
self.is_fp8 = is_fp8_str_dtype(self.weight_precision)
self.expert_is_fp8 = is_fp8_str_dtype(self.expert_weight_precision)
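        # int4 weights are kept in int8 storage with two 4-bit values packed
        # per byte, so pack_factor is 2 for int4 and 1 for 8-bit dtypes
        # (int8/fp8).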
self.pack_factor = 8 // self.weight_bits
self.expert_pack_factor = 8 // self.expert_weight_bits
def __repr__(self) -> str:
return (f"SmoothQuantConfig(input_quant_method={self.input_quant_method}, "
f"quant_mode={self.quant_mode}, "
f"group_size={self.group_size}, "
f"weight_precision={self.weight_precision}, "
f"activation_precision={self.activation_precision}, "
f"only_expert_per_group={self.only_expert_per_group}, "
f"expert_weight_precision={self.expert_weight_precision}, "
f"expert_activation_precision={self.expert_activation_precision}, "
f"force_use_weightonly_except_expert={self.force_use_weightonly_except_expert})")
@classmethod
    def get_name(cls) -> str:
return "SmoothQuant"
@classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@staticmethod
def get_config_filenames() -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
quant_mode = cls.get_from_keys(config, ["quant_mode"])
input_quant_method = cls.get_from_keys(config, ["input_quant_method"])
group_size = cls.get_from_keys_or(config, ["group_size"], 1)
weight_precision = cls.get_from_keys_or(config, ["weight_precision"], "int8")
activation_precision = cls.get_from_keys_or(config, ["activation_precision"], "int8")
only_expert_per_group = cls.get_from_keys_or(config, ["only_expert_per_group"], False)
expert_weight_precision = cls.get_from_keys_or(config, ["expert_weight_precision"], None)
expert_activation_precision = cls.get_from_keys_or(config, ["expert_activation_precision"], None)
force_use_weightonly_except_expert = cls.get_from_keys_or(config, ["force_use_weightonly_except_expert"], False)
if expert_weight_precision is None:
expert_weight_precision = weight_precision
if group_size > 1 and only_expert_per_group and weight_precision == 'int4':
weight_precision = 'int8'
if expert_activation_precision is None:
expert_activation_precision = activation_precision
return cls(quant_mode=quant_mode,
input_quant_method=input_quant_method,
group_size=group_size,
weight_precision=weight_precision,
activation_precision=activation_precision,
only_expert_per_group=only_expert_per_group,
expert_weight_precision=expert_weight_precision,
expert_activation_precision=expert_activation_precision,
force_use_weightonly_except_expert=force_use_weightonly_except_expert)
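    # Illustrative quantize_config.json consumed by from_config() above.
    # Keys mirror the lookups in from_config(); the values below are only an
    # example, not a recommendation:
    #
    #   {
    #       "quant_mode": "SmoothQuant",
    #       "input_quant_method": "per_token",
    #       "group_size": 1,
    #       "weight_precision": "int8",
    #       "activation_precision": "int8",
    #       "only_expert_per_group": false,
    #       "force_use_weightonly_except_expert": false
    #   }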
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["SmoothQuantLinearMethod"]:
if isinstance(layer, LinearBase):
return SmoothQuantLinearMethod(self, prefix)
return None
def get_scaled_act_names(self) -> List[str]:
return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"]
class SmoothQuantLinearMethod(LinearMethodBase):
"""Linear method for SmoothQuant.
Args:
quant_config: The SmoothQuant quantization config.
"""
def __init__(self, quant_config: SmoothQuantConfig, prefix: str):
self.quant_config = quant_config
        # For the per-tensor case, we can skip quantizing the input of the
        # first attn/ffn linear and fuse this step into the preceding
        # layernorm for better performance.
self.skip_quant_input = False
self.compute_dtype = torch.get_default_dtype()
self.is_expert = 'expert' in prefix and "shared_expert" not in prefix
self.weight_dtype = quant_config.expert_weight_dtype if self.is_expert else quant_config.weight_dtype
self.pack_factor = quant_config.expert_pack_factor if self.is_expert else quant_config.pack_factor
self.is_fp8 = quant_config.expert_is_fp8 if self.is_expert else quant_config.is_fp8
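        # Group-wise (sub-channel) weight quantization is active when
        # group_size > 1; with only_expert_per_group it is restricted to
        # expert layers.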
if quant_config.only_expert_per_group and self.is_expert and quant_config.group_size > 1:
self.is_group_quant = True
elif quant_config.only_expert_per_group is False and quant_config.group_size > 1:
self.is_group_quant = True
else:
self.is_group_quant = False
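        # The smoothing vector is only loaded for per_token activation
        # quantization; force_use_weightonly_except_expert additionally skips
        # it for non-expert layers.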
self.has_smooth = self.quant_config.input_quant_method == "per_token" and (
self.quant_config.force_use_weightonly_except_expert is False or self.is_expert)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
output_size_per_partition = sum(output_partition_sizes)
if (output_size_per_partition % self.quant_config.pack_factor != 0):
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
weight_loader = extra_weight_attrs.get("weight_loader")
group_num = 1
if self.is_group_quant:
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
f"The input size {input_size_per_partition} is not aligned with the quantized "
f"weight shape. This can be caused by too large "
f"tensor parallel size. group_size: {self.quant_config.group_size}.")
group_num = (input_size + self.quant_config.group_size - 1) // self.quant_config.group_size
if input_size_per_partition != input_size:
group_num = (input_size_per_partition + self.quant_config.group_size - 1) // self.quant_config.group_size
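        # Quantized weight of shape
        # [output_size_per_partition, input_size_per_partition // pack_factor];
        # int4 values are packed pairwise along the input dim into int8 storage.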
qweight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition // self.pack_factor,
device="mlu",
dtype=self.weight_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
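        # Weight dequantization scales: one per (output channel, group) for
        # group-wise quantization, otherwise one per output channel.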
if self.is_group_quant:
per_channel_scale = GroupQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
group_num,
device="mlu",
dtype=torch.float32,
),
input_dim=1,
output_dim=0,
weight_loader=weight_loader,
)
else:
per_channel_scale = ChannelQuantScaleParameter(
data=torch.empty(
output_size_per_partition,
device="mlu",
dtype=torch.float32,
),
output_dim=0,
weight_loader=weight_loader,
)
layer.register_parameter("qweight", qweight)
layer.register_parameter("per_channel_scale", per_channel_scale)
if self.has_smooth:
smooth = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
device="mlu",
dtype=torch.float32,
),
input_dim=0,
weight_loader=weight_loader,
)
set_weight_attrs(smooth, {
"ignore_warning": True,
})
layer.register_parameter("smooth", smooth)
if self.quant_config.input_quant_method == "per_tensor":
scale_to_int = RowvLLMParameter(
data=torch.empty(
input_size_per_partition,
device="mlu",
dtype=torch.float32,
),
input_dim=0,
weight_loader=weight_loader,
)
set_weight_attrs(scale_to_int, {
"ignore_warning": True,
})
layer.register_parameter("scale_to_int", scale_to_int)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
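        # Cast scales to float32 if needed and re-wrap the loaded tensors as
        # plain, frozen Parameters for inference.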
if self.has_smooth and layer.smooth.dtype != torch.float:
layer.smooth = layer.smooth.to(torch.float)
if self.quant_config.input_quant_method == "per_tensor" and layer.scale_to_int.dtype != torch.float:
layer.scale_to_int = layer.scale_to_int.to(torch.float)
if layer.per_channel_scale.dtype != torch.float:
layer.per_channel_scale = layer.per_channel_scale.to(torch.float)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.per_channel_scale = Parameter(layer.per_channel_scale.data, requires_grad=False)
if self.has_smooth:
layer.smooth = Parameter(layer.smooth.data, requires_grad=False)
if self.quant_config.input_quant_method == "per_tensor":
layer.scale_to_int = Parameter(layer.scale_to_int.data, requires_grad=False)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
residual: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
use_tp_weight : bool = False,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
layer_smooth = layer.smooth if self.has_smooth else None
layer_qweight = layer.qweight
layer_per_channel_scale = layer.per_channel_scale
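        # Prefer tensor-parallel-specific copies of the weights/scales when
        # the caller requests them and the layer provides them.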
if use_tp_weight:
if hasattr(layer, 'tp_smooth'):
layer_smooth = layer.tp_smooth
if hasattr(layer, 'tp_qweight'):
layer_qweight = layer.tp_qweight
if hasattr(layer, 'tp_per_channel_scale'):
layer_per_channel_scale = layer.tp_per_channel_scale
quant_input = None
if self.skip_quant_input:
quant_input = x
elif self.quant_config.input_quant_method == "per_token":
if self.is_fp8:
quant_input, input_scale = mlu_ops.scaled_quantize(x,
layer_smooth,
quant_type=self.weight_dtype,
quant_mode='dynamic_per_token')
else:
quant_input, input_scale = mlu_ops.per_token_smooth_quantize(x, layer_smooth, None)
elif self.quant_config.input_quant_method == "per_tensor":
quant_input = mlu_ops.quantize(x, layer.scale_to_int, None)
else:
raise ValueError(
"Currently, only per_token or per_tensor input quantization is supported for "
f"SmoothQuant, but got {self.input_quant_method}.")
quant_input_shape = quant_input.shape
if len(quant_input_shape) > 2:
quant_input = quant_input.view(-1, quant_input_shape[-1])
input_scale = input_scale.view(-1)
if residual is not None and len(residual.shape) > 2:
residual = residual.view(-1, residual.shape[-1])
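        # fp8 path: scaled_matmul takes the dynamically computed per-token
        # input scales together with the per-channel weight scales.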
if self.is_fp8:
out = mlu_ops.scaled_matmul(quant_input, layer_qweight, input_scale,
layer_per_channel_scale,
self.compute_dtype if hasattr(self, 'compute_dtype') else x.dtype,
bias,
                                        c=residual, act_mode="none", quant_bit_size=8,
alpha=1.0, beta=1.0, use_hp_active=False,
a_quant_bit_size=8, a_calib=None, b_calib=None)
if output is not None:
out = out.view(output.shape)
output.copy_(out)
out = output
else:
if output is not None:
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
layer_per_channel_scale, self.compute_dtype, bias, residual, output=output)
else:
out = mlu_ops.smooth_quant_matmul(quant_input, input_scale, layer_qweight,
layer_per_channel_scale, self.compute_dtype, bias, residual)
if len(quant_input_shape) > 2:
out = out.view(*quant_input_shape[:-1], out.shape[-1])
return out