# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Optional

import torch

from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig,
    int4_w4a16_moe_quant_config,
    int8_w8a16_moe_quant_config,
)
from vllm.model_executor.layers.fused_moe.layer import (
    FusedMoE,
    FusedMoEConfig,
    FusedMoEMethodBase,
    FusedMoeWeightScaleSupported,
)
from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
    UnquantizedFusedMoEMethod,
)
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
    check_marlin_supports_layer,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform


class MoeWNA16Config(QuantizationConfig):
    """Config class for MOE WNA16 (W8A16/W4A16) quantization.

    The config wraps a GPTQ or AWQ checkpoint configuration. Fused MoE
    layers are handled by ``MoeWNA16Method``; ordinary linear layers are
    delegated to the GPTQ/AWQ (or their Marlin variants) quantization
    configs built from ``full_config``.
    """

    def __init__(
        self,
        linear_quant_method: str,
        weight_bits: int,
        group_size: int,
        has_zp: bool,
        lm_head_quantized: bool,
        modules_to_not_convert: list[str] | None,
        full_config: dict[str, Any],
    ) -> None:
        """Store the quantization parameters and pick the linear backend.

        Args:
            linear_quant_method: Quantization method of the checkpoint;
                must be ``"gptq"``, ``"awq"`` or ``"awq_marlin"``.
            weight_bits: Number of bits per quantized weight.
            group_size: Quantization group size from the checkpoint config.
            has_zp: Whether the checkpoint carries zero points.
            lm_head_quantized: Whether the LM head is quantized.
            modules_to_not_convert: Substrings of layer prefixes that must
                stay unquantized; ``None`` is treated as an empty list.
            full_config: The raw checkpoint quantization config, reused to
                build the delegated GPTQ/AWQ configs for linear layers.

        Raises:
            ValueError: If ``linear_quant_method`` is unsupported, or if an
                AWQ method is requested on a device whose capability is
                below ``AWQConfig.get_min_capability()``.
        """
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.has_zp = has_zp
        # Number of quantized values packed into one uint8.
        self.bit8_pack_factor = 8 // self.weight_bits
        self.lm_head_quantized = lm_head_quantized
        self.linear_quant_method = linear_quant_method
        self.full_config = full_config
        self.use_marlin = False
        # Avoid circular import
        from vllm.model_executor.layers.quantization.awq import AWQConfig
        from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
        from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinConfig

        if self.linear_quant_method == "gptq":
            self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
        elif self.linear_quant_method in ("awq", "awq_marlin"):
            capability_tuple = current_platform.get_device_capability()
            # -1 stands in for "capability unknown" so the check below fails.
            device_capability = (
                -1 if capability_tuple is None else capability_tuple.to_int()
            )
            awq_min_capability = AWQConfig.get_min_capability()
            if device_capability < awq_min_capability:
                raise ValueError(
                    "The quantization method moe_wna16 + awq is not supported "
                    "for the current GPU. "
                    f"Minimum capability: {awq_min_capability}. "
                    f"Current capability: {device_capability}."
                )
            self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible(full_config)
        else:
            raise ValueError("moe_wna16 only support gptq and awq.")

        if modules_to_not_convert is None:
            self.modules_to_not_convert = []
        else:
            self.modules_to_not_convert = modules_to_not_convert

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registered name of this quantization method."""
        return "moe_wna16"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        """Return the activation dtypes supported by this method."""
        return [torch.bfloat16, torch.half]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (as an int)."""
        return 70

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        """Return the file names searched for the quantization config."""
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config":
        """Build a ``MoeWNA16Config`` from a GPTQ or AWQ checkpoint config.

        Reads ``quant_method``, ``bits``, ``group_size`` and (optionally)
        ``lm_head``. For GPTQ the zero-point flag is the negation of ``sym``
        and no modules are excluded; for AWQ it is ``zero_point`` and
        ``modules_to_not_convert`` is read if present.

        Raises:
            ValueError: If ``quant_method`` is not gptq/awq/awq_marlin.
        """
        linear_quant_method = cls.get_from_keys(config, ["quant_method"])
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        if linear_quant_method == "gptq":
            has_zp = not cls.get_from_keys(config, ["sym"])
            modules_to_not_convert = []
        elif linear_quant_method in ("awq", "awq_marlin"):
            has_zp = cls.get_from_keys(config, ["zero_point"])
            modules_to_not_convert = cls.get_from_keys_or(
                config, ["modules_to_not_convert"], None
            )
        else:
            raise ValueError("moe_wna16 only support gptq and awq.")

        return cls(
            linear_quant_method,
            weight_bits,
            group_size,
            has_zp,
            lm_head_quantized,
            modules_to_not_convert,
            config,
        )

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant
    ) -> QuantizationMethods | None:
        """Select moe_wna16 only when the user explicitly asked for it.

        Returns:
            ``"moe_wna16"`` if the checkpoint config is compatible and
            ``user_quant`` equals ``"moe_wna16"``; otherwise ``None``.
        """
        can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
        if can_convert and user_quant == "moe_wna16":
            return cls.get_name()
        return None

    @classmethod
    def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]):
        """Return whether a checkpoint quantization config can use moe_wna16.

        Compatible configs are GPTQ with 4 or 8 bits and a falsy
        ``desc_act``, or 4-bit AWQ on a device whose capability is at least
        ``AWQConfig.get_min_capability()``.
        """
        # Extract data from quant config.
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        desc_act = quant_config.get("desc_act")

        capability_tuple = current_platform.get_device_capability()
        device_capability = (
            -1 if capability_tuple is None else capability_tuple.to_int()
        )
        # Avoid circular import
        from vllm.model_executor.layers.quantization.awq import AWQConfig

        awq_min_capability = AWQConfig.get_min_capability()

        gptq_compatible = quant_method == "gptq" and not desc_act and num_bits in [4, 8]
        awq_compatible = (
            quant_method == "awq"
            and num_bits == 4
            and device_capability >= awq_min_capability
        )

        return gptq_compatible or awq_compatible

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Return the quantize method to use for ``layer``.

        Layers whose ``prefix`` matches ``modules_to_not_convert`` get the
        unquantized method (fused-MoE or linear). Other ``LinearBase``
        layers are delegated to the GPTQ/AWQ config (Marlin variant when
        ``use_marlin`` is set and, for AWQ, the layer is Marlin-compatible).
        ``FusedMoE`` layers get ``MoeWNA16Method``. Anything else returns
        ``None``.

        Raises:
            ValueError: If ``linear_quant_method`` is unsupported.
        """
        if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
            if isinstance(layer, FusedMoE):
                return UnquantizedFusedMoEMethod(layer.moe_config)
            return UnquantizedLinearMethod()
        elif isinstance(layer, LinearBase):
            # Avoid circular import
            from vllm.model_executor.layers.quantization.awq import AWQConfig
            from vllm.model_executor.layers.quantization.awq_marlin import (
                AWQMarlinConfig,
            )
            from vllm.model_executor.layers.quantization.gptq import GPTQConfig
            from vllm.model_executor.layers.quantization.gptq_marlin import (
                GPTQMarlinConfig,
            )

            if self.linear_quant_method == "gptq":
                if self.use_marlin:
                    return GPTQMarlinConfig.from_config(
                        self.full_config
                    ).get_quant_method(layer, prefix)
                else:
                    return GPTQConfig.from_config(self.full_config).get_quant_method(
                        layer, prefix
                    )
            elif self.linear_quant_method in ("awq", "awq_marlin"):
                if self.use_marlin and check_marlin_supports_layer(
                    layer, self.group_size
                ):
                    return AWQMarlinConfig.from_config(
                        self.full_config
                    ).get_quant_method(layer, prefix)
                else:
                    return AWQConfig.from_config(self.full_config).get_quant_method(
                        layer, prefix
                    )
            else:
                raise ValueError("moe_wna16 only support gptq and awq.")
        elif isinstance(layer, FusedMoE):
            return MoeWNA16Method(self, layer.moe_config)
        return None


def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]):
    """Return True if any excluded module name is a substring of ``prefix``."""
    return any(module_name in prefix for module_name in modules_to_not_convert)


class MoeWNA16Method(FusedMoEMethodBase):
    """Fused MoE method for MOE WNA16 (W8A16/W4A16) quantization.

    Args:
        quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
        moe: The fused MoE config forwarded to ``FusedMoEMethodBase``.
    """

    def __init__(self, quant_config: MoeWNA16Config, moe: "FusedMoEConfig") -> None:
        super().__init__(moe)
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the packed weights, scales and zero points on ``layer``.

        Quantized weights and zero points are stored as ``uint8`` with
        ``bit8_pack_factor`` values per byte; scales use ``params_dtype``.
        The effective group size is halved until it divides both
        ``intermediate_size_per_partition`` and ``hidden_size``; the
        resulting size and division factor are stored on the layer as
        ``group_size`` and ``group_size_div_factor``.

        Args:
            layer: Module that receives the parameters.
            num_experts: Leading dimension of every expert parameter.
            hidden_size: Model hidden size.
            intermediate_size_per_partition: Per-partition intermediate size.
            params_dtype: Dtype of the scale parameters.
            **extra_weight_attrs: Attributes attached to every parameter;
                must contain ``"weight_loader"``, which gets wrapped by
                ``get_weight_loader``.
        """
        layer.quant_config = self.quant_config
        bit8_pack_factor = self.quant_config.bit8_pack_factor
        group_size = self.quant_config.group_size
        group_size_div_factor = 1

        # make intermediate_size and hidden_size divisible by group_size
        # we reduce the group size to ensure that
        # and we would repeat the loaded_weight later
        while intermediate_size_per_partition % group_size or hidden_size % group_size:
            group_size = group_size // 2
            group_size_div_factor *= 2
            assert group_size >= 32
        layer.group_size = group_size
        layer.group_size_div_factor = group_size_div_factor

        strategy = FusedMoeWeightScaleSupported.GROUP.value
        extra_weight_attrs.update({"quant_method": strategy, "is_transposed": False})

        assert "weight_loader" in extra_weight_attrs
        weight_loader = extra_weight_attrs["weight_loader"]
        wrapped_weight_loader = MoeWNA16Method.get_weight_loader(layer, weight_loader)
        extra_weight_attrs["weight_loader"] = wrapped_weight_loader

        # Fused gate_up_proj (column parallel)
        w13_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // bit8_pack_factor,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_qweight", w13_qweight)
        set_weight_attrs(w13_qweight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_qweight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // bit8_pack_factor,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_qweight", w2_qweight)
        set_weight_attrs(w2_qweight, extra_weight_attrs)

        w13_scales = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // group_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_scales", w13_scales)
        set_weight_attrs(w13_scales, extra_weight_attrs)

        w2_scales = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // group_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_scales", w2_scales)
        set_weight_attrs(w2_scales, extra_weight_attrs)

        if self.quant_config.has_zp:
            # Zero points are packed along the output-channel dimension.
            w13_qzeros = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    2 * intermediate_size_per_partition // bit8_pack_factor,
                    hidden_size // group_size,
                    dtype=torch.uint8,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_qzeros", w13_qzeros)
            set_weight_attrs(w13_qzeros, extra_weight_attrs)

            w2_qzeros = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    hidden_size // bit8_pack_factor,
                    intermediate_size_per_partition // group_size,
                    dtype=torch.uint8,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w2_qzeros", w2_qzeros)
            set_weight_attrs(w2_qzeros, extra_weight_attrs)

        if self.quant_config.linear_quant_method == "gptq":
            # some param are unused, but we need to init them in order to
            # load weights
            invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
            if not self.quant_config.has_zp:
                invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
            for key in invalid_param_keys:
                param = torch.nn.Parameter(
                    torch.empty((0,), dtype=torch.int32), requires_grad=False
                )
                layer.register_parameter(key, param)
                set_weight_attrs(param, extra_weight_attrs)

    def get_fused_moe_quant_config(
        self, layer: torch.nn.Module
    ) -> FusedMoEQuantConfig | None:
        """Build the fused-MoE quant config from the layer's scales/zeros.

        Uses the int4 builder for 4-bit weights and the int8 builder for
        8-bit weights; zero points are passed only when ``has_zp`` is set.
        """
        weight_bits = self.quant_config.weight_bits
        has_zp = self.quant_config.has_zp
        assert weight_bits == 4 or weight_bits == 8
        config_builder = (
            int4_w4a16_moe_quant_config
            if weight_bits == 4
            else int8_w8a16_moe_quant_config
        )

        return config_builder(
            w1_scale=layer.w13_scales,
            w2_scale=layer.w2_scales,
            w1_zp=layer.w13_qzeros if has_zp else None,
            w2_zp=layer.w2_qzeros if has_zp else None,
            block_shape=[0, layer.group_size],
        )

    def apply(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Route ``x`` through the experts using the quantized weights.

        Expert selection is done by ``layer.select_experts``; the result is
        fed to ``fused_experts`` together with ``self.moe_quant_config``.
        Only the ``"silu"`` activation is accepted.
        """
        from vllm.model_executor.layers.fused_moe import fused_experts

        assert layer.activation == "silu", "Only SiLU activation is supported."
        topk_weights, topk_ids, _ = layer.select_experts(
            hidden_states=x,
            router_logits=router_logits,
        )

        return fused_experts(
            x,
            layer.w13_qweight,
            layer.w2_qweight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            apply_router_weight_on_input=layer.apply_router_weight_on_input,
            global_num_experts=layer.global_num_experts,
            expert_map=layer.expert_map,
            quant_config=self.moe_quant_config,
        )

    @staticmethod
    def get_weight_loader(layer, weight_loader):
        """Wrap ``weight_loader`` so GPTQ/AWQ tensors load into this layout.

        The returned loader converts checkpoint tensors to the uint8-packed
        layout created by ``create_weights``, repeats scales/zero points
        when the group size was reduced, writes zero points directly into
        the parameter, and delegates everything else to ``weight_loader``.

        Args:
            layer: The layer whose ``quant_config``, ``group_size_div_factor``,
                ``intermediate_size_per_partition`` and ``tp_size`` are read
                at load time.
            weight_loader: The original loader to delegate to.
        """

        def convert_awq_tensor(tensor, tensor_type):
            # convert awq qweight/qzeros to a standard format (assume int4)
            # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
            # qzeros: (k // group_size, n // pack_factor_bit32) ->
            #         (n // pack_factor_bit8, k // group_size)
            # pack_factor_bit32 = 32 // weight_bits
            # pack_factor_bit8 = 8 // weight_bits

            # 0. suppose origin shape (a, b), dtype int32
            # 1. convert to uint8, shape (a, b) -> (a, 4 * b)
            size0 = tensor.size(0)
            tensor = tensor.view(torch.uint8)

            # 2. unpack to uint4 (only when weight_bits == 4)
            #    shape (a, 4 * b) -> (a, 4 * b, 2)
            shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF

            # 3. change order, see
            # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
            # shape -> (a, 4 * b * pack_factor_bit8)
            reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
            tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
            tensor = tensor.view(size0, -1)

            # 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
            tensor = tensor.T.contiguous()

            # 5. repack (only when weight_bits == 4)
            # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
            # qzeros shape -> (4 * b, a)

            if tensor_type == "qweight":
                tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
            elif tensor_type == "qzeros":
                tensor = tensor[1::2, :] * 16 + tensor[::2, :]
            return tensor

        def convert_gptq_int4_qzeros(tensor):
            # Unpack each byte into two 4-bit values, add 1 to each (to
            # align GPTQ zero points with AWQ), then repack into one byte.
            # NOTE(review): the incremented nibbles are not masked back to
            # 4 bits before repacking - confirm a stored value of 15 cannot
            # occur or is handled as intended.
            tensor = tensor.view(torch.uint8)
            shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
            tensor = (tensor[:, :, None] >> shifter) & 0xF
            tensor = tensor + 1
            tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
            return tensor

        def moe_wna16_weight_loader(
            param: torch.nn.Parameter,
            loaded_weight: torch.Tensor,
            weight_name: str,
            shard_id: str,
            expert_id: int,
            return_success: bool = False,
        ):
            """Load one checkpoint tensor into ``param`` for ``expert_id``.

            Returns ``True``/``False`` when ``return_success`` is set (or
            whatever the wrapped loader returns when delegating), otherwise
            ``None``.
            """
            # g_idx tensors, and qzeros of checkpoints without zero points,
            # are placeholders only and are never loaded.
            if "g_idx" in weight_name:
                return False if return_success else None
            if not layer.quant_config.has_zp and "qzeros" in weight_name:
                return False if return_success else None

            device = get_tp_group().device
            tp_rank = get_tensor_model_parallel_rank()
            loaded_weight = loaded_weight.to(device)
            shard_size = layer.intermediate_size_per_partition

            # convert gptq and awq weight to a standard format
            # awq_marlin uses the same weight format as awq
            if layer.quant_config.linear_quant_method in ("awq", "awq_marlin"):
                assert layer.quant_config.weight_bits == 4
                if "weight" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight, "qweight")
                elif "zeros" in weight_name:
                    loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
                else:
                    loaded_weight = loaded_weight.T
            elif layer.quant_config.linear_quant_method == "gptq":
                assert layer.quant_config.weight_bits in [4, 8]
                if "weight" in weight_name:
                    loaded_weight = loaded_weight.T.contiguous().view(torch.uint8)
                elif "zeros" in weight_name:
                    # add 1 to gptq qzeros to align with awq
                    loaded_weight = loaded_weight.view(torch.uint8)
                    if layer.quant_config.weight_bits == 4:
                        loaded_weight = convert_gptq_int4_qzeros(loaded_weight).T
                    else:
                        loaded_weight = loaded_weight.T + 1
                else:
                    loaded_weight = loaded_weight.T

            # repeat the qzeros/scales to fit new group size
            # The parentheses matter: without them `and` binds tighter than
            # `or`, so every scales tensor was copied through
            # repeat_interleave even when the factor is 1.
            if layer.group_size_div_factor > 1 and (
                "qzeros" in weight_name or "scales" in weight_name
            ):
                loaded_weight = loaded_weight.repeat_interleave(
                    layer.group_size_div_factor, 1
                )

            if "w13_qzeros" in weight_name:
                tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
                    tp_rank
                ]
                # w13_qzeros holds the w1 half followed by the w3 half, each
                # with shard_size // bit8_pack_factor packed rows (see
                # create_weights). A hard-coded "// 2" is only right for
                # 4-bit weights.
                zp_shard_size = shard_size // layer.quant_config.bit8_pack_factor
                if shard_id == "w1":
                    param.data[expert_id, :zp_shard_size] = tensor
                else:
                    param.data[expert_id, zp_shard_size:] = tensor
                return True if return_success else None
            elif "w2_qzeros" in weight_name:
                param.data[expert_id] = loaded_weight.view(
                    loaded_weight.size(0), layer.tp_size, -1
                )[:, tp_rank]
                return True if return_success else None
            else:
                # Delegate to the original loader, passing return_success
                return weight_loader(
                    param,
                    loaded_weight,
                    weight_name,
                    shard_id,
                    expert_id,
                    return_success=return_success,
                )

        return moe_wna16_weight_loader