# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations

import functools
import logging
import warnings
from typing import Any, Callable, Dict, List, Optional

import torch

from sglang.srt.layers.linear import LinearBase, set_weight_attrs
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    LinearMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.marlin_utils import (
    apply_awq_marlin_linear,
    awq_to_marlin_zero_points,
    check_marlin_supported,
    check_marlin_supports_layer,
    check_moe_marlin_supports_layer,
    marlin_make_empty_g_idx,
    marlin_make_workspace,
    marlin_moe_permute_scales,
    marlin_permute_scales,
    moe_awq_to_marlin_zero_points,
    verify_marlin_supported,
    verify_marlin_supports_shape,
)
from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import replace_parameter

try:
    from vllm import _custom_ops as ops

    warnings.warn(
        "Using kernels directly from vllm. This might lead to performance degradation or "
        "missing functionalities as certain kernels may not be optimized. "
    )
except ImportError:
    # vllm is optional; the Marlin repacking paths check for this explicitly.
    ops = None

from sglang.srt.utils import is_cuda, is_hip

_is_cuda = is_cuda()
_is_hip = is_hip()
if _is_cuda:
    from sgl_kernel import awq_dequantize, fused_marlin_moe
elif _is_hip:
    from sglang.srt.layers.quantization.awq_triton import (
        awq_dequantize_triton as awq_dequantize,
    )

    warnings.warn("HIP does not support fused_marlin_moe currently.")
else:
    warnings.warn("Only CUDA and HIP support AWQ currently.")

logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=None)
def _warning_once(msg: str, *args: Any) -> None:
    """Log ``msg % args`` at WARNING level once per distinct ``(msg, args)``.

    The standard library ``logging.Logger`` has no ``warning_once`` method, so
    de-duplication is done here instead of depending on another package having
    patched one onto the logger. All arguments must be hashable.
    """
    logger.warning(msg, *args)


def _require_vllm_ops() -> Any:
    """Return vllm's ``_custom_ops`` module.

    Raises:
        ImportError: If vllm could not be imported when this module was loaded.
    """
    if ops is None:
        raise ImportError(
            "AWQ Marlin weight repacking requires vllm's `_custom_ops`, "
            "but vllm could not be imported."
        )
    return ops


def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]) -> bool:
    """Return True if any entry of ``modules_to_not_convert`` is a substring of ``prefix``."""
    return any(module_name in prefix for module_name in modules_to_not_convert)


class AWQConfig(QuantizationConfig):
    """Config class for AWQ.

    Reference: https://arxiv.org/abs/2306.00978
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        modules_to_not_convert: Optional[List[str]] = None,
    ) -> None:
        """Store the AWQ settings and derive the int32 packing factor.

        Raises:
            ValueError: If ``weight_bits`` is not 4.
        """
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.modules_to_not_convert = modules_to_not_convert or []

        if self.weight_bits != 4:
            raise ValueError(
                "Currently, only 4-bit weight quantization is supported for "
                f"AWQ, but got {self.weight_bits} bits."
            )
        # Number of quantized values stored in one 32-bit word.
        self.pack_factor = 32 // self.weight_bits

    def __repr__(self) -> str:
        return (
            f"AWQConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    def get_scaled_act_names(self) -> List[str]:
        return []

    def get_name(self) -> str:
        return "awq"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        return [torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        # The AWQ kernel only supports Turing or newer GPUs.
        return 75

    @staticmethod
    def get_config_filenames() -> List[str]:
        return [
            "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
            # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
            "quantize_config.json",
        ]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> AWQConfig:
        """Build an ``AWQConfig`` from a quantization config dict.

        Each setting is looked up under either of two key spellings
        (``w_bit``/``bits`` and ``q_group_size``/``group_size``);
        ``modules_to_not_convert`` is optional and defaults to ``None``.
        """
        weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
        group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(weight_bits, group_size, zero_point, modules_to_not_convert)

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[LinearMethodBase]:
        """Pick the quantization method for ``layer``.

        Returns an unquantized method for linear layers whose ``prefix`` matches
        ``modules_to_not_convert``, ``AWQLinearMethod`` for other linear layers,
        and ``None`` for anything that is not a ``LinearBase``.
        """
        if isinstance(layer, LinearBase):
            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                return UnquantizedLinearMethod()
            return AWQLinearMethod(self)
        return None


class AWQMarlinConfig(QuantizationConfig):
    """Config class for AWQ Marlin"""

    # num_bits -> type
    TYPE_MAP = {
        4: scalar_types.uint4,
        8: scalar_types.uint8,
    }

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        lm_head_quantized: bool,
        modules_to_not_convert: Optional[list[str]],
        full_config: dict[str, Any],
    ) -> None:
        """Store the AWQ Marlin settings and validate them.

        ``full_config`` is kept so that ``get_quant_method`` can rebuild a
        fallback config for layers Marlin does not support.

        Raises:
            ValueError: If ``weight_bits`` is not a key of ``TYPE_MAP``.
        """
        super().__init__()
        self.group_size = group_size
        self.zero_point = zero_point
        self.lm_head_quantized = lm_head_quantized
        self.weight_bits = weight_bits
        self.modules_to_not_convert = modules_to_not_convert or []
        self.full_config = full_config

        # Validate before dividing so that e.g. weight_bits == 0 produces the
        # intended ValueError instead of a ZeroDivisionError.
        if self.weight_bits not in self.TYPE_MAP:
            raise ValueError(
                f"Unsupported num_bits = {self.weight_bits}. "
                f"Supported num_bits = {self.TYPE_MAP.keys()}"
            )

        self.pack_factor = 32 // self.weight_bits  # packed into int32
        self.quant_type = self.TYPE_MAP[self.weight_bits]

        verify_marlin_supported(
            self.quant_type, group_size=self.group_size, has_zp=self.zero_point
        )

    def __repr__(self) -> str:
        return (
            f"AWQMarlinConfig(quant_type={self.quant_type}, "
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    def get_scaled_act_names(self) -> List[str]:
        return []

    @classmethod
    def get_name(cls) -> str:
        return "awq_marlin"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> AWQMarlinConfig:
        """Build an ``AWQMarlinConfig`` from a quantization config dict.

        ``lm_head`` defaults to ``False`` and ``modules_to_not_convert`` to
        ``None`` when absent; the whole ``config`` dict is retained as
        ``full_config``.
        """
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            weight_bits,
            group_size,
            zero_point,
            lm_head_quantized,
            modules_to_not_convert,
            config,
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        """Return ``"awq_marlin"`` when the checkpoint can be run with Marlin.

        The override applies only if ``hf_quant_cfg`` is Marlin-compatible and
        the user asked for no method, ``"marlin"`` or ``"awq_marlin"``.
        Otherwise ``None`` is returned; an explicit ``"awq"`` request on a
        compatible checkpoint is honoured and only logged.
        """
        can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
        is_valid_user_quant = (
            user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
        )

        if can_convert and is_valid_user_quant:
            msg = (
                "The model is convertible to {} during runtime."
                " Using {} kernel.".format(cls.get_name(), cls.get_name())
            )
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "awq":
            logger.info(
                "Detected that the model can run with awq_marlin"
                ", however you specified quantization=awq explicitly,"
                " so forcing awq. Use quantization=awq_marlin for"
                " faster inference"
            )
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        """Pick the quantization method for ``layer``.

        Linear layers (and the LM head when ``lm_head_quantized`` is set) get
        ``AWQMarlinLinearMethod``, unless skipped via ``modules_to_not_convert``
        or unsupported by Marlin, in which case the plain AWQ method is used.
        ``FusedMoE`` layers get ``AWQMoEMethod`` or fall back to the MoE WNA16
        method. Any other layer yields ``None``.
        """
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

        if isinstance(layer, LinearBase) or (
            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
        ):
            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                return UnquantizedLinearMethod()
            # Check if the layer is supported by AWQMarlin.
            if not check_marlin_supports_layer(layer, self.group_size):
                _warning_once(
                    "Layer '%s' is not supported by AWQMarlin. Falling back to unoptimized AWQ kernels.",  # noqa: E501
                    prefix,
                )
                return AWQConfig.from_config(self.full_config).get_quant_method(
                    layer, prefix
                )
            return AWQMarlinLinearMethod(self)
        elif isinstance(layer, FusedMoE):
            from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config

            if not check_moe_marlin_supports_layer(layer, self.group_size):
                _warning_once(
                    "Layer '%s' is not supported by AWQMoeMarlin. "
                    "Falling back to Moe WNA16 kernels.",
                    prefix,
                )
                return MoeWNA16Config.from_config(self.full_config).get_quant_method(
                    layer, prefix
                )
            return AWQMoEMethod(self)
        return None

    @classmethod
    def is_awq_marlin_compatible(cls, quant_config: dict[str, Any]):
        """Return whether ``quant_config`` describes an AWQ checkpoint Marlin can run.

        Always ``False`` off CUDA, for non-``awq`` quant methods, when
        ``bits``/``group_size``/``zero_point`` is missing, or when ``bits`` is
        not in ``TYPE_MAP``; otherwise defers to ``check_marlin_supported``.
        """
        # Extract data from quant config.
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        zero_point = quant_config.get("zero_point")

        if not _is_cuda:
            return False

        if quant_method != "awq":
            return False

        # If we cannot find the info needed in the config, cannot convert.
        if num_bits is None or group_size is None or zero_point is None:
            return False

        if num_bits not in cls.TYPE_MAP:
            return False

        return check_marlin_supported(
            quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point
        )


class AWQLinearMethod(LinearMethodBase):
    """Linear method for AWQ.

    Args:
        quant_config: The AWQ quantization config.
    """

    def __init__(self, quant_config: AWQConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register empty ``qweight``, ``qzeros`` and ``scales`` parameters on ``layer``.

        Raises:
            ValueError: If the per-partition input size is not a multiple of the
                group size, or the per-partition output size is not a multiple
                of the pack factor.
        """
        if input_size_per_partition % self.quant_config.group_size != 0:
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        output_size_per_partition = sum(output_partition_sizes)
        if output_size_per_partition % self.quant_config.pack_factor != 0:
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size."
            )

        weight_loader = extra_weight_attrs.get("weight_loader")
        # Quantized weights: packed along the output dimension into int32 words.
        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        # Zero points: one row per quantization group, packed like qweight.
        qzeros = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        # Scales: one row per quantization group, unpacked, in params_dtype.
        scales = GroupQuantScaleParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.group_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap the loaded tensors as plain, non-trainable ``torch.nn.Parameter``s."""
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Dequantize the layer's weights and compute ``x @ W (+ bias)``.

        The input is flattened to 2-D for the matmul and the result is reshaped
        so that its leading dimensions match those of ``x``.
        """
        qweight = layer.qweight
        scales = layer.scales
        qzeros = layer.qzeros
        pack_factor = self.quant_config.pack_factor
        # The last qweight dimension is packed, so unpacked width is shape * pack_factor.
        out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
        reshaped_x = x.reshape(-1, x.shape[-1])

        out = awq_dequantize(qweight, scales, qzeros)
        out = torch.matmul(reshaped_x, out)

        if bias is not None:
            out.add_(bias)
        return out.reshape(out_shape)


class AWQMarlinLinearMethod(LinearMethodBase):
    """Linear method for AWQ Marlin.

    Args:
        quant_config: The AWQ Marlin quantization config.
    """

    def __init__(self, quant_config: AWQMarlinConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        """Register empty AWQ-format parameters and record the partition sizes.

        Besides ``qweight``, ``qzeros`` and ``scales``, this stores
        ``input_size_per_partition``, ``output_size_per_partition`` and
        ``num_groups`` on ``layer`` for use after loading.
        """
        del output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            # -1 means a single group spanning the whole input dimension.
            group_size = input_size

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size,
        )

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        num_groups = input_size_per_partition // group_size

        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.num_groups = num_groups

    # TODO: Update this docs
    # Checkpoints are serialized in AutoAWQ format, which is different from the
    # marlin format. This function is called after the weights are loaded.
    # Here, we handle the repacking
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded AWQ-format tensors on ``layer`` to Marlin format.

        Raises:
            ImportError: If vllm's custom ops are unavailable.
        """
        # Fail before touching the layer if the repack kernel is unavailable.
        vllm_ops = _require_vllm_ops()

        device = layer.qweight.device
        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)

        # Allocate marlin workspace
        layer.workspace = marlin_make_workspace(device)

        # Repack weights from AWQ format to marlin format.
        marlin_qweight = vllm_ops.awq_marlin_repack(
            layer.qweight,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "qweight", marlin_qweight)

        # Permute scales from AWQ format to marlin format.
        marlin_scales = marlin_permute_scales(
            layer.scales,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "scales", marlin_scales)

        # Permute zero-points from AWQ format to marlin format.
        marlin_zp = awq_to_marlin_zero_points(
            layer.qzeros,
            size_k=layer.num_groups,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "qzeros", marlin_zp)

        # Not-used
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the Marlin AWQ linear kernel on ``x`` using the layer's repacked tensors."""
        return apply_awq_marlin_linear(
            input=x,
            weight=layer.qweight,
            weight_scale=layer.scales,
            weight_zp=layer.qzeros,
            g_idx=layer.g_idx,
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            quant_type=self.quant_config.quant_type,
            output_size_per_partition=layer.output_size_per_partition,
            input_size_per_partition=layer.input_size_per_partition,
            bias=bias,
        )


class AWQMoEMethod(FusedMoEMethodBase):
    """Fused-MoE method for 4-bit AWQ weights using the Marlin kernels."""

    def __init__(self, quant_config: AWQMarlinConfig):
        """Bind the config.

        Raises:
            ValueError: If ``quant_config.weight_bits`` is not 4.
        """
        self.quant_config = quant_config
        if self.quant_config.weight_bits != 4:
            raise ValueError("AWQMoEMethod only supports 4bit now.")
        self.quant_type = scalar_types.uint4

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register empty per-expert quantized weights, scales and zero points.

        ``w13_*`` tensors hold two projections side by side (output width
        ``2 * intermediate_size_per_partition``); ``w2_*`` tensors map back to
        ``hidden_size``. A Marlin workspace is also allocated on ``layer``.
        """
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        extra_weight_attrs.update(
            {
                "is_transposed": True,
                "quant_method": FusedMoeWeightScaleSupported.GROUP.value,
            }
        )

        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition,
                hidden_size // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        num_groups_w13 = hidden_size // self.quant_config.group_size
        num_groups_w2 = intermediate_size_per_partition // self.quant_config.group_size

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        w13_scales = torch.nn.Parameter(
            torch.empty(
                num_experts,
                num_groups_w13,
                intermediate_size_per_partition * 2,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = torch.nn.Parameter(
            torch.empty(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
            requires_grad=False,
        )
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        # WEIGHT_ZERO_POINT
        # Allocate 2 zero points for w1 and w3 respectively.
        w13_qzeros = torch.nn.Parameter(
            torch.empty(
                num_experts,
                num_groups_w13,
                2 * intermediate_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qzeros", w13_qzeros)
        set_weight_attrs(w13_qzeros, extra_weight_attrs)

        w2_qzeros = torch.nn.Parameter(
            torch.empty(
                num_experts,
                num_groups_w2,
                hidden_size // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qzeros", w2_qzeros)
        set_weight_attrs(w2_qzeros, extra_weight_attrs)

        device = layer.w13_qweight.device
        layer.workspace = marlin_make_workspace(device, 4)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Convert the loaded per-expert AWQ tensors on ``layer`` to Marlin format.

        Raises:
            ImportError: If vllm's custom ops are unavailable.
        """
        # Fail before touching the layer if the repack kernel is unavailable.
        vllm_ops = _require_vllm_ops()

        num_experts = layer.w13_qweight.shape[0]
        device = layer.w13_qweight.device

        # Empty (num_experts, 0) sort-index placeholders passed to the kernels.
        layer.w13_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )
        layer.w2_g_idx_sort_indices = torch.nn.Parameter(
            torch.empty((num_experts, 0), dtype=torch.int32, device=device),
            requires_grad=False,
        )

        marlin_w13_qweight = vllm_ops.awq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            size_k=layer.w13_qweight.shape[1],
            size_n=layer.w13_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)

        marlin_w2_qweight = vllm_ops.awq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            size_k=layer.w2_qweight.shape[1],
            size_n=layer.w2_qweight.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

        # hidden_size->intermediate_size
        # NOTE(review): w13_qweight is repacked with size_k = its dim 1
        # (hidden_size in create_weights), but the scales below are permuted
        # with size_k = layer.intermediate_size_per_partition, an attribute
        # this class never sets (presumably provided by the MoE layer).
        # Confirm this is intended.
        marlin_w13_scales = marlin_moe_permute_scales(
            s=layer.w13_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w13_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w13_scales", marlin_w13_scales)

        marlin_w2_scales = marlin_moe_permute_scales(
            s=layer.w2_scales,
            size_k=layer.intermediate_size_per_partition,
            size_n=layer.w2_scales.shape[2],
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "w2_scales", marlin_w2_scales)

        marlin_w13_zp = moe_awq_to_marlin_zero_points(
            layer.w13_qzeros,
            size_k=layer.w13_qzeros.shape[1],
            size_n=layer.w13_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qzeros", marlin_w13_zp)

        marlin_w2_zp = moe_awq_to_marlin_zero_points(
            layer.w2_qzeros,
            size_k=layer.w2_qzeros.shape[1],
            size_n=layer.w2_qzeros.shape[2] * self.quant_config.pack_factor,
            num_bits=self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qzeros", marlin_w2_zp)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        num_fused_shared_experts: int = 0,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        """Route ``x`` to experts and run the fused Marlin MoE kernel.

        The input is cast to float16 for the kernel and the result is cast back
        to the input's original dtype. ``apply_router_weight_on_input`` is
        accepted but not used here.

        Raises:
            AssertionError: If ``activation`` is not ``"silu"`` or
                ``scoring_func`` is not ``"softmax"``.
        """
        # Delay the import to avoid circular dependency
        from sglang.srt.layers.moe.topk import select_experts

        assert activation == "silu", "Only SiLU activation is supported."
        assert (
            scoring_func == "softmax"
        ), "Only softmax score func is supported for now."

        # The input must currently be float16
        orig_dtype = x.dtype
        x = x.half()

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            num_fused_shared_experts=num_fused_shared_experts,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
            routed_scaling_factor=routed_scaling_factor,
        )

        return fused_marlin_moe(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            layer.w13_scales,
            layer.w2_scales,
            router_logits,
            topk_weights,
            topk_ids,
            sort_indices1=layer.w13_g_idx_sort_indices,
            sort_indices2=layer.w2_g_idx_sort_indices,
            w1_zeros=layer.w13_qzeros,
            w2_zeros=layer.w2_qzeros,
            num_bits=self.quant_config.weight_bits,
        ).to(orig_dtype)