diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py
index 1da678c58..407437aee 100644
--- a/python/sglang/srt/layers/quantization/__init__.py
+++ b/python/sglang/srt/layers/quantization/__init__.py
@@ -9,7 +9,6 @@ import torch
 
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.awq import AWQConfig
     from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
@@ -33,14 +32,15 @@ except ImportError:
     class DummyConfig:
         pass
 
-    AQLMConfig = AWQConfig = AWQMarlinConfig = BitsAndBytesConfig = (
-        CompressedTensorsConfig
-    ) = DummyConfig
+    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
+        DummyConfig
+    )
     DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
         GPTQMarlin24Config
     ) = DummyConfig
     MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
 
+from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py
new file mode 100644
index 000000000..7a8af13aa
--- /dev/null
+++ b/python/sglang/srt/layers/quantization/awq.py
@@ -0,0 +1,200 @@
+# SPDX-License-Identifier: Apache-2.0
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+import torch
+from sgl_kernel import awq_dequantize
+
+from sglang.srt.layers.linear import (
+    LinearBase,
+    LinearMethodBase,
+    UnquantizedLinearMethod,
+)
+from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
+logger = logging.getLogger(__name__)
+
+
+def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
+    return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
+class AWQConfig(QuantizationConfig):
+    """Config class for AWQ.
+
+    Reference: https://arxiv.org/abs/2306.00978
+    """
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        zero_point: bool,
+        modules_to_not_convert: Optional[List[str]] = None,
+    ) -> None:
+        super().__init__()
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.zero_point = zero_point
+        self.modules_to_not_convert = modules_to_not_convert or []
+
+        if self.weight_bits != 4:
+            raise ValueError(
+                "Currently, only 4-bit weight quantization is supported for "
+                f"AWQ, but got {self.weight_bits} bits."
+            )
+        self.pack_factor = 32 // self.weight_bits
+
+    def __repr__(self) -> str:
+        return (
+            f"AWQConfig(weight_bits={self.weight_bits}, "
+            f"group_size={self.group_size}, "
+            f"zero_point={self.zero_point}, "
+            f"modules_to_not_convert={self.modules_to_not_convert})"
+        )
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+    def get_name(self) -> str:
+        return "awq"
+
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # The AWQ kernel only supports Turing or newer GPUs.
+        return 75
+
+    @staticmethod
+    def get_config_filenames() -> List[str]:
+        return [
+            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
+            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
+            "quantize_config.json",
+        ]
+
+    @classmethod
+    def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
+        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
+        zero_point = cls.get_from_keys(config, ["zero_point"])
+        modules_to_not_convert = cls.get_from_keys_or(
+            config, ["modules_to_not_convert"], None
+        )
+        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["LinearMethodBase"]:
+
+        if isinstance(layer, LinearBase):
+            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                return UnquantizedLinearMethod()
+            return AWQLinearMethod(self)
+        return None
+
+
+class AWQLinearMethod(LinearMethodBase):
+    """Linear method for AWQ.
+
+    Args:
+        quant_config: The AWQ quantization config.
+    """
+
+    def __init__(self, quant_config: AWQConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        if input_size_per_partition % self.quant_config.group_size != 0:
+            raise ValueError(
+                "The input size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                "The output size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        qzeros = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        scales = GroupQuantScaleParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            input_dim=0,
+            output_dim=1,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("qweight", qweight)
+        layer.register_parameter("qzeros", qzeros)
+        layer.register_parameter("scales", scales)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qweight = layer.qweight
+        scales = layer.scales
+        qzeros = layer.qzeros
+        pack_factor = self.quant_config.pack_factor
+        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        out = awq_dequantize(qweight, scales, qzeros)
+        out = torch.matmul(reshaped_x, out)
+
+        if bias is not None:
+            out.add_(bias)
+        return out.reshape(out_shape)
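
For review, here is a pure-PyTorch sketch of the math `AWQLinearMethod.apply` performs: dequantize the packed weights to fp16, then do a dense matmul. `sgl_kernel.awq_dequantize` fuses this on GPU. The helper names (`unpack_int32`, `dequantize_ref`) are illustrative, not part of this PR, and the interleaved nibble order `[0, 4, 1, 5, 2, 6, 3, 7]` is assumed from the AutoAWQ packing convention, so treat this as a shape/math reference rather than a drop-in replacement for the kernel:

```python
import torch

# Assumption: AutoAWQ packs eight 4-bit values per int32 in an interleaved
# column order; indexing with this permutation restores logical order.
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def unpack_int32(packed: torch.Tensor) -> torch.Tensor:
    """Unpack (m, n) int32 into (m, n * 8) 4-bit values in logical order."""
    shifts = torch.arange(0, 32, 4, device=packed.device)
    nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF   # (m, n, 8), LSB first
    nibbles = nibbles[..., AWQ_REVERSE_ORDER]          # undo AWQ interleaving
    return nibbles.reshape(packed.shape[0], -1)


def dequantize_ref(qweight, scales, qzeros, group_size=128):
    """Dequantize packed AWQ weights to a dense (in, out) matrix."""
    w = unpack_int32(qweight).to(scales.dtype)         # (in, out)
    z = unpack_int32(qzeros).to(scales.dtype)          # (in // group, out)
    z = z.repeat_interleave(group_size, dim=0)         # expand zeros per group
    s = scales.repeat_interleave(group_size, dim=0)    # expand scales per group
    return (w - z) * s


# Shapes mirror create_weights with in=512, out=256, group_size=128:
# qweight (in, out // 8), qzeros (in // group, out // 8), scales (in // group, out).
x = torch.randn(4, 512, dtype=torch.half)
qweight = torch.randint(0, 2**31 - 1, (512, 32), dtype=torch.int32)
qzeros = torch.randint(0, 2**31 - 1, (4, 32), dtype=torch.int32)
scales = torch.rand(4, 256, dtype=torch.half)

out = x @ dequantize_ref(qweight, scales, qzeros)      # same math as apply()
```

Note that `apply` materializes the full fp16 weight matrix before the matmul rather than using a fused quantized GEMM; the output width is recovered as `qweight.shape[-1] * pack_factor`, which is why `create_weights` rejects output partitions not divisible by the pack factor.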