Files
enginex-bi_series-vllm/vllm/model_executor/layers/quantization/smoothquant.py
2025-08-07 07:25:16 +00:00

112 lines
3.8 KiB
Python

from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_world_size
class SmoothQuantConfig(QuantizationConfig):
    """Config class for SmoothQuant.

    Reference: https://github.com/mit-han-lab/smoothquant

    Only 8-bit weights with tensor-wise quantization are accepted; any other
    combination is rejected in ``__init__``.
    """

    def __init__(
        self,
        weight_bits: int,
        quant_type: str = "tensor",
    ) -> None:
        """Validate and store the quantization settings.

        Args:
            weight_bits: Number of bits used for the weights. Must be 8.
            quant_type: Quantization granularity. Must be ``"tensor"``.

        Raises:
            ValueError: If ``weight_bits`` is not 8 or ``quant_type`` is not
                ``"tensor"``.
        """
        self.weight_bits = weight_bits
        self.quant_type = quant_type
        if self.weight_bits != 8:
            raise ValueError(
                "Currently, only w8a8 quantization is supported for "
                f"SmoothQuant, but got {self.weight_bits} bits.")
        if self.quant_type != "tensor":
            raise ValueError(
                "Currently, only tensor wise quantization is supported for "
                f"SmoothQuant, but got {self.quant_type} type quantization.")

    def __repr__(self) -> str:
        return (f"SmoothQuantConfig(weight_bits={self.weight_bits}, "
                f"quant_type={self.quant_type})")

    def get_name(self) -> str:
        """Return the identifier of this quantization method."""
        return "smoothquant"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        """Return the supported activation dtypes (float16 and float32)."""
        return [torch.half, torch.float]

    def get_min_capability(self) -> int:
        """Return the minimum compute capability required (70)."""
        return 70

    @staticmethod
    def get_config_filenames() -> List[str]:
        """List of filenames to search for in the model directory."""
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig":
        """Build a config from a parsed quantization-config dict.

        The bit width is read from ``"w_bit"`` or ``"bits"`` (first match
        wins). The quantization type is read from ``"quant_type"`` or
        ``"q_type"``; when neither key is present it falls back to the
        constructor default ``"tensor"`` instead of failing.

        Raises:
            ValueError: If the values are unsupported (see ``__init__``), or
                if the base-class key lookup finds no bit-width key.
        """
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        # The type key is optional, mirroring the constructor's default.
        quant_type = next(
            (config[key] for key in ("quant_type", "q_type") if key in config),
            "tensor",
        )
        return cls(weight_bits, quant_type)

    def get_linear_method(self) -> "SmoothLinearMethod":
        """Create the linear method for the current tensor-parallel size."""
        return SmoothLinearMethod(
            world_size=get_tensor_model_parallel_world_size())

    def get_scaled_act_names(self) -> List[str]:
        """Return activation names that need scaling (none for SmoothQuant)."""
        return []
class SmoothLinearMethod(LinearMethodBase):
    """Linear method for SmoothQuant W8A8 layers.

    Weights are stored as int8, and the matmul writes an int32 result via
    ``ops.linear_a8_w8_o32_``. For row-parallel layers with tensor
    parallelism active (``world_size > 1``), the result is dequantized here
    to the parameter dtype. Otherwise the raw int32 tensor is returned;
    presumably the calling layer dequantizes it (confirm against callers).
    """

    def __init__(self, world_size: int, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Row-parallel outputs are dequantized inside apply_weights only when
        # the tensor-parallel world size is greater than one.
        self.apply_dequant_after_row = world_size > 1
        # Floating-point dtype for dequantized outputs; set by create_weights.
        self.dtype: Optional[torch.dtype] = None

    @property
    def dtpye(self) -> Optional[torch.dtype]:
        """Misspelled alias of ``dtype``, kept for backward compatibility."""
        return self.dtype

    @dtpye.setter
    def dtpye(self, value: Optional[torch.dtype]) -> None:
        self.dtype = value

    def create_weights(
        self,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        """Allocate the int8 weight for one (partitioned) linear layer.

        Args:
            input_size_per_partition: Input features on this rank.
            output_size_per_partition: Output features on this rank.
            input_size: Full input size (unused here).
            output_size: Full output size (unused here).
            params_dtype: Dtype recorded for dequantized outputs.

        Returns:
            ``{"weight": weight}``, where ``weight`` is an int8 ``Parameter``
            of shape ``(output_size_per_partition, input_size_per_partition)``
            tagged with ``input_dim=1`` and ``output_dim=0``.
        """
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=torch.int8),
                           requires_grad=False)
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        # If this instance is shared by several layers, the last call wins.
        self.dtype = params_dtype
        return {"weight": weight}

    def apply_weights(
        self,
        weights: Dict[str, torch.Tensor],
        x: torch.Tensor,
        bias: Optional[torch.Tensor],
        scale: Optional[torch.Tensor] = None,
        dequant_scale: float = 1.0,
        is_row: bool = False,
    ) -> torch.Tensor:
        """Run the int8 x int8 -> int32 matmul, optionally dequantizing.

        Args:
            weights: Dict holding the int8 ``"weight"`` from create_weights.
            x: Input activations; all leading dims are flattened for the
                matmul. Presumably already int8-quantized (TODO confirm).
            bias: Accepted for interface compatibility but NOT applied here.
            scale: Dequantization scale tensor; required when dequantizing.
            dequant_scale: Extra scalar factor passed to ``ops.dequant``.
            is_row: Whether the calling layer is row-parallel.

        Returns:
            A tensor of shape ``(*x.shape[:-1], out_features)``. It is int32,
            unless ``is_row`` is set and tensor parallelism is active; then it
            is dequantized to the dtype recorded by create_weights.

        Raises:
            ValueError: If dequantization is required but ``scale`` is None.
            RuntimeError: If dequantization is required but create_weights
                has not recorded an output dtype.
        """
        weight = weights["weight"]
        x_shape = x.shape
        # reshape (not view) copies non-contiguous inputs instead of raising;
        # inputs with compatible strides are still viewed without a copy.
        x_2d = x.reshape(-1, x_shape[-1])
        y = torch.empty((x_2d.shape[0], weight.shape[0]),
                        dtype=torch.int32,
                        device=x.device)
        ops.linear_a8_w8_o32_(x_2d, weight, y)
        y = y.view(*x_shape[:-1], -1)
        # NOTE(review): ``bias`` is not used here; presumably the calling
        # layer adds it after dequantization. Confirm.
        if not (is_row and self.apply_dequant_after_row):
            return y
        # With tensor parallelism, dequantize row-parallel outputs here; the
        # original author's note suggests this was meant to improve accuracy.
        if scale is None:
            raise ValueError(
                "SmoothQuant dequantization requires a scale tensor, "
                "but scale is None.")
        if self.dtype is None:
            raise RuntimeError(
                "SmoothQuant output dtype is unknown; create_weights must be "
                "called before apply_weights.")
        out = torch.empty_like(y, dtype=self.dtype)
        ops.dequant(out, y, scale, dequant_scale)
        return out