# Source file: sglang/python/sglang/srt/layers/quantization/qoq.py
# (246 lines, 7.9 KiB, Python)

from __future__ import annotations
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from sglang.srt.layers.parameter import (
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
ModelWeightParameter,
)
from sglang.srt.layers.quantization.base_config import (
LinearMethodBase,
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import is_cuda
# Values accepted by QoQConfig; anything else is rejected with a ValueError.
QoQ_SUPPORTED_WEIGHT_BITS = [4]
# A group size of -1 selects the per-channel weight layout.
QoQ_SUPPORTED_GROUP_SIZES = [-1, 128]

# Probe the platform once at import time. The two QServe W4A8 GEMM entry
# points are only imported when is_cuda() reports True, so they are undefined
# names on every other platform.
_is_cuda = is_cuda()
if _is_cuda:
    from sgl_kernel import (
        qserve_w4a8_per_chn_gemm,
        qserve_w4a8_per_group_gemm,
    )
class QoQConfig(QuantizationConfig):
    """Config class for QoQ Quantization.

    - Weight: static, per-channel/group, asymmetric
    - Activation: dynamic, per-token, symmetric

    Reference: https://arxiv.org/abs/2405.04532
               https://github.com/mit-han-lab/omniserve

    Args:
        weight_bits: Bit-width of the quantized weights; must be listed in
            ``QoQ_SUPPORTED_WEIGHT_BITS``.
        group_size: Weight quantization group size; must be listed in
            ``QoQ_SUPPORTED_GROUP_SIZES``. ``-1`` selects the per-channel
            layout in ``QoQLinearMethod``.

    Raises:
        ValueError: If ``weight_bits`` or ``group_size`` is not supported.
    """

    def __init__(self, weight_bits: int, group_size: int) -> None:
        # Run the base-class initializer so that any state the base config
        # sets up in its own __init__ exists on this instance. This is a
        # no-op if the base class defines no initializer.
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size

        # Reject unsupported settings up front, before any weights are built.
        if self.weight_bits not in QoQ_SUPPORTED_WEIGHT_BITS:
            raise ValueError(
                f"QoQ does not support weight_bits = {self.weight_bits}. "
                f"Only weight_bits = {QoQ_SUPPORTED_WEIGHT_BITS} "
                "are supported."
            )
        if self.group_size not in QoQ_SUPPORTED_GROUP_SIZES:
            raise ValueError(
                f"QoQ does not support group_size = {self.group_size}. "
                f"Only group_sizes = {QoQ_SUPPORTED_GROUP_SIZES} "
                "are supported."
            )

        # 4 bits packed into 8 bit datatype.
        self.pack_factor = 8 // self.weight_bits

    def __repr__(self) -> str:
        return (
            f"QoQConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size})"
        )

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this scheme accepts (float16 only)."""
        return [torch.float16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required.

        NOTE(review): 80 presumably encodes CUDA compute capability 8.0;
        confirm against the base-class contract.
        """
        return 80

    @classmethod
    def get_name(cls) -> str:
        """Return the identifier of this quantization method."""
        return "qoq"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """List of filenames to search for in the model directory."""
        return [
            "quant_config.json",
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> QoQConfig:
        """Build a config from a parsed quantization config dictionary.

        Reads the ``"wbits"`` and ``"group_size"`` entries through
        ``get_from_keys`` and forwards them to the constructor, which
        validates them.
        """
        weight_bits = cls.get_from_keys(config, ["wbits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        return cls(weight_bits, group_size)

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional[QuantizeMethodBase]:
        """Return the quantize method for ``layer``, or ``None``.

        Only ``LinearBase`` layers get a ``QoQLinearMethod``; every other
        layer type gets ``None``. ``prefix`` is not used.
        """
        # Imported inside the method rather than at module level
        # (presumably to avoid a circular import; verify before moving).
        from sglang.srt.layers.linear import LinearBase

        if isinstance(layer, LinearBase):
            return QoQLinearMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """Return the scaled activation names; always empty for QoQ."""
        return []
class QoQLinearMethod(LinearMethodBase):
    """Linear method for QoQ.

    Args:
        quant_config: The QoQ quantization config.
    """

    def __init__(self, quant_config: QoQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate and register the quantized weight tensors on ``layer``.

        Always registers ``qweight`` (int8, packed along the input dimension)
        and ``s1_scales`` (float16, one value per output channel). With
        ``group_size == -1`` it also registers ``s1_szeros`` (float16, per
        output channel); otherwise it registers ``s2_scales`` and
        ``s2_zeros`` (int8, one row per input group).

        ``input_size``, ``output_size`` and ``params_dtype`` are accepted but
        not used. The ``"weight_loader"`` entry of ``extra_weight_attrs`` is
        passed to every parameter.

        Raises:
            ValueError: If the summed output partition size is not a multiple
                of 32, or the input partition size is not a multiple of the
                pack factor or (when grouping is enabled) of the group size.
        """
        loader = extra_weight_attrs.get("weight_loader")

        # Output-side check on the total over all fused partitions.
        out_features = sum(output_partition_sizes)
        if out_features % 32:
            raise ValueError(
                f"Weight output_size_per_partition = "
                f"{out_features} is not divisible by 32."
            )

        # Input-side checks: the packed dimension must split evenly.
        pack_factor = self.quant_config.pack_factor
        if input_size_per_partition % pack_factor:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"pack_factor = {pack_factor}."
            )

        group_size = self.quant_config.group_size
        per_channel = group_size == -1
        if not per_channel and input_size_per_partition % group_size:
            raise ValueError(
                f"Weight input_size_per_partition = "
                f"{input_size_per_partition} is not divisible by "
                f"group_size = {group_size}."
            )

        # Packed weight: pack_factor quantized values per int8 element.
        packed_shape = (out_features, input_size_per_partition // pack_factor)
        layer.register_parameter(
            "qweight",
            ModelWeightParameter(
                data=torch.empty(packed_shape, dtype=torch.int8),
                input_dim=1,
                output_dim=0,
                weight_loader=loader,
            ),
        )

        def new_channel_param():
            # One float16 value per output channel.
            return ChannelQuantScaleParameter(
                data=torch.empty(out_features, dtype=torch.float16),
                output_dim=0,
                weight_loader=loader,
            )

        layer.register_parameter("s1_scales", new_channel_param())

        if per_channel:
            layer.register_parameter("s1_szeros", new_channel_param())
            return

        # Per-group path: two int8 tensors with identical layout.
        group_shape = (input_size_per_partition // group_size, out_features)
        for param_name in ("s2_scales", "s2_zeros"):
            layer.register_parameter(
                param_name,
                GroupQuantScaleParameter(
                    data=torch.empty(group_shape, dtype=torch.int8),
                    input_dim=0,
                    output_dim=1,
                    weight_loader=loader,
                ),
            )

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap every loaded tensor as a plain, non-trainable Parameter.

        Covers ``qweight`` and ``s1_scales``, plus ``s1_szeros`` for the
        per-channel layout or ``s2_scales`` and ``s2_zeros`` for the
        per-group layout.
        """
        param_names = ["qweight", "s1_scales"]
        if self.quant_config.group_size == -1:
            param_names.append("s1_szeros")
        else:
            param_names.extend(("s2_scales", "s2_zeros"))

        for param_name in param_names:
            loaded = getattr(layer, param_name)
            setattr(layer, param_name, Parameter(loaded.data, requires_grad=False))

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        """Quantize ``x`` per token to int8 and run the matching QServe GEMM.

        With ``group_size == -1`` the per-channel kernel is used and the
        quantizer is also asked for its sum output (``cal_sum=True``);
        otherwise the per-group kernel is used. ``bias`` is added to the
        result when given.

        NOTE(review): the two GEMM functions are only imported when
        ``is_cuda()`` is true, so this method presumably needs a CUDA build.
        """
        assert x.dtype == torch.float16, "QoQ only supports float16 input now"

        if self.quant_config.group_size != -1:
            quantized, act_scale = per_token_quant_int8(x, scale_dtype=x.dtype)
            out = qserve_w4a8_per_group_gemm(
                quantized,
                layer.qweight,
                layer.s2_zeros,
                layer.s2_scales,
                layer.s1_scales,
                act_scale,
            )
        else:
            quantized, act_scale, act_sum = per_token_quant_int8(
                x, scale_dtype=x.dtype, cal_sum=True
            )
            out = qserve_w4a8_per_chn_gemm(
                quantized,
                layer.qweight,
                layer.s1_scales,
                act_scale,
                layer.s1_szeros,
                act_sum,
            )

        return out if bias is None else out + bias