# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from typing import Any, Callable, Optional import torch import os from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported, UnquantizedFusedMoEMethod) from vllm.model_executor.layers.linear import (LinearBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_marlin_supports_layer) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.model_executor.layers.fused_moe import fused_experts from vllm.model_executor.layers.quantization.awq import ( is_layer_skipped_awq) from lmslim.layers.fused_moe.fuse_moe_int4 import fused_experts_w4a16 os.environ['W4A16_MOE_CUDA'] = os.environ.get('W4A16_MOE_CUDA', '0') os.environ['W4A16_MOE_LMSLIM'] = os.environ.get('W4A16_MOE_LMSLIM', '1') if os.environ['W4A16_MOE_CUDA'] == '1': from vllm.model_executor.layers.quantization.utils.fused_moe_cuda import fused_experts_cuda class MoeWNA16Config(QuantizationConfig): """Config class for MOE WNA16 (W8A16/W4A16) quantization.""" def __init__(self, linear_quant_method: str, weight_bits: int, group_size: int, has_zp: bool, lm_head_quantized: bool, modules_to_not_convert: Optional[list[str]], full_config: dict[str, Any]) -> None: super().__init__() self.weight_bits = weight_bits self.group_size = group_size self.has_zp = has_zp self.bit8_pack_factor = 8 // self.weight_bits self.lm_head_quantized = lm_head_quantized self.linear_quant_method = linear_quant_method self.full_config = full_config self.use_marlin = False # Avoid circular import from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq_marlin import ( AWQMarlinConfig) from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) if self.linear_quant_method == "gptq": self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible( full_config) elif self.linear_quant_method == "awq": capability_tuple = current_platform.get_device_capability() device_capability = (-1 if capability_tuple is None else capability_tuple.to_int()) awq_min_capability = AWQConfig.get_min_capability() if device_capability < awq_min_capability: raise ValueError( "The quantization method moe_wna16 + awq is not supported " "for the current GPU. " f"Minimum capability: {awq_min_capability}. " f"Current capability: {device_capability}.") self.use_marlin = AWQMarlinConfig.is_awq_marlin_compatible( full_config) else: raise ValueError("moe_wna16 only support gptq and awq.") if modules_to_not_convert is None: self.modules_to_not_convert = [] else: self.modules_to_not_convert = modules_to_not_convert @classmethod def get_name(cls) -> QuantizationMethods: return "moe_wna16" @classmethod def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16, torch.half] @classmethod def get_min_capability(cls) -> int: return 70 @classmethod def get_config_filenames(cls) -> list[str]: return ["quantize_config.json"] @classmethod def from_config(cls, config: dict[str, Any]) -> "MoeWNA16Config": linear_quant_method = cls.get_from_keys(config, ["quant_method"]) weight_bits = cls.get_from_keys(config, ["bits"]) group_size = cls.get_from_keys(config, ["group_size"]) lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False) if linear_quant_method == "gptq": has_zp = not cls.get_from_keys(config, ["sym"]) modules_to_not_convert = [] elif linear_quant_method == "awq": has_zp = cls.get_from_keys(config, ["zero_point"]) modules_to_not_convert = cls.get_from_keys_or( config, ["modules_to_not_convert"], None) else: raise ValueError("moe_wna16 only support gptq and awq.") return cls(linear_quant_method, weight_bits, group_size, has_zp, lm_head_quantized, modules_to_not_convert, config) @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]: can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) if can_convert and user_quant == "moe_wna16": return cls.get_name() return None @classmethod def is_moe_wna16_compatible(cls, quant_config: dict[str, Any]): # Extract data from quant config. quant_method = quant_config.get("quant_method", "").lower() num_bits = quant_config.get("bits") desc_act = quant_config.get("desc_act") capability_tuple = current_platform.get_device_capability() device_capability = (-1 if capability_tuple is None else capability_tuple.to_int()) # Avoid circular import from vllm.model_executor.layers.quantization.awq import AWQConfig awq_min_capability = AWQConfig.get_min_capability() gptq_compatible = quant_method == "gptq" and \ not desc_act and num_bits in [4, 8] awq_compatible = quant_method == "awq" and num_bits == 4 and \ device_capability >= awq_min_capability return gptq_compatible or awq_compatible def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["QuantizeMethodBase"]: if isinstance(layer, LinearBase): if is_layer_skipped_quant(prefix, self.modules_to_not_convert): return UnquantizedLinearMethod() # Avoid circular import from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.awq_marlin import ( AWQMarlinConfig) from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinConfig) if self.linear_quant_method == "gptq": if self.use_marlin: return GPTQMarlinConfig.from_config( self.full_config).get_quant_method(layer, prefix) else: return GPTQConfig.from_config( self.full_config).get_quant_method(layer, prefix) elif self.linear_quant_method == "awq": if self.use_marlin and check_marlin_supports_layer( layer, self.group_size): return AWQMarlinConfig.from_config( self.full_config).get_quant_method(layer, prefix) else: return AWQConfig.from_config( self.full_config).get_quant_method(layer, prefix) else: raise ValueError("moe_wna16 only support gptq and awq.") elif isinstance(layer, FusedMoE): if is_layer_skipped_awq( prefix, getattr(self, "modules_to_not_convert", [])): return UnquantizedFusedMoEMethod(layer.moe_config) return MoeWNA16Method(self) return None def is_layer_skipped_quant(prefix: str, modules_to_not_convert: list[str]): return any(module_name in prefix for module_name in modules_to_not_convert) class MoeWNA16Method(FusedMoEMethodBase): """Linear method for MOE WNA16 (W8A16/W4A16) quantization. Args: quant_config: The MOE WNA16 (W8A16/W4A16) quantization config. """ def __init__(self, quant_config: MoeWNA16Config): self.quant_config = quant_config self.use_w4a16_moe_sz = os.environ.get('AWQ_MOE_SZ') == '1' self.use_w4a16_cuda = 0 self.use_moe_lmslim = 0 if self.use_w4a16_moe_sz: self.use_w4a16_cuda = os.environ['W4A16_MOE_CUDA'] == '1' self.use_moe_lmslim = os.environ['W4A16_MOE_LMSLIM'] == "1" def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs): layer.quant_config = self.quant_config bit8_pack_factor = self.quant_config.bit8_pack_factor group_size = self.quant_config.group_size group_size_div_factor = 1 # make intermediate_size and hidden_size diviable by group_size # we reduce the group size to ensure that # and we would repeat the loaded_weight later while intermediate_size_per_partition % group_size or \ hidden_size % group_size: group_size = group_size // 2 group_size_div_factor *= 2 assert group_size >= 32 layer.group_size = group_size layer.group_size_div_factor = group_size_div_factor strategy = FusedMoeWeightScaleSupported.GROUP.value extra_weight_attrs.update({ "quant_method": strategy, "is_transposed": False }) assert 'weight_loader' in extra_weight_attrs weight_loader = extra_weight_attrs['weight_loader'] wrapped_weight_loader = MoeWNA16Method.get_weight_loader( layer, weight_loader) extra_weight_attrs['weight_loader'] = wrapped_weight_loader # Fused gate_up_proj (column parallel) w13_qweight = torch.nn.Parameter(torch.empty( num_experts, 2 * intermediate_size_per_partition, hidden_size // bit8_pack_factor, dtype=torch.uint8), requires_grad=False) layer.register_parameter("w13_qweight", w13_qweight) set_weight_attrs(w13_qweight, extra_weight_attrs) # down_proj (row parallel) w2_qweight = torch.nn.Parameter(torch.empty( num_experts, hidden_size, intermediate_size_per_partition // bit8_pack_factor, dtype=torch.uint8), requires_grad=False) layer.register_parameter("w2_qweight", w2_qweight) set_weight_attrs(w2_qweight, extra_weight_attrs) w13_scales = torch.nn.Parameter(torch.zeros( num_experts, 2 * intermediate_size_per_partition, hidden_size // group_size, dtype=params_dtype), requires_grad=False) layer.register_parameter("w13_scales", w13_scales) set_weight_attrs(w13_scales, extra_weight_attrs) w2_scales = torch.nn.Parameter(torch.zeros( num_experts, hidden_size, intermediate_size_per_partition // group_size, dtype=params_dtype), requires_grad=False) layer.register_parameter("w2_scales", w2_scales) set_weight_attrs(w2_scales, extra_weight_attrs) if self.quant_config.has_zp: w13_qzeros = torch.nn.Parameter(torch.zeros( num_experts, 2 * intermediate_size_per_partition // bit8_pack_factor, hidden_size // group_size, dtype=torch.uint8), requires_grad=False) layer.register_parameter("w13_qzeros", w13_qzeros) set_weight_attrs(w13_qzeros, extra_weight_attrs) w2_qzeros = torch.nn.Parameter(torch.zeros( num_experts, hidden_size // bit8_pack_factor, intermediate_size_per_partition // group_size, dtype=torch.uint8), requires_grad=False) layer.register_parameter("w2_qzeros", w2_qzeros) set_weight_attrs(w2_qzeros, extra_weight_attrs) if self.quant_config.linear_quant_method == "gptq": # some param are unused, but we need to init them in order to # load weights invalid_param_keys = ["w13_g_idx", "w2_g_idx"] if not self.quant_config.has_zp: invalid_param_keys += ["w13_qzeros", "w2_qzeros"] for key in invalid_param_keys: param = torch.nn.Parameter(torch.empty((0, ), dtype=torch.int32), requires_grad=False) layer.register_parameter(key, param) set_weight_attrs(param, extra_weight_attrs) def restore_qzeros_tensor(self, qzeros, qscales): low_bits = qzeros & 0x0F high_bits = qzeros >> 4 zeors_tensor = torch.stack([low_bits, high_bits], dim=2).view(qzeros.shape[0], -1 , qzeros.shape[-1]) zeors_int16 = zeors_tensor.to(torch.int16) assert zeors_int16.shape == qscales.shape uint16_tensor1 = zeors_int16.view(torch.uint16) uint16_tensor2 = qscales.view(torch.uint16) uint32_tensor1 = uint16_tensor1.to(torch.int32) << 16 uint32_tensor2 = uint16_tensor2.to(torch.int32) result_tensor = uint32_tensor1 + uint32_tensor2 result_tensor =result_tensor.view(torch.uint32) result_tensor = result_tensor.transpose(1, 2).contiguous() return result_tensor def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if self.use_w4a16_moe_sz: sz_tensor_1 = self.restore_qzeros_tensor(layer.w13_qzeros, layer.w13_scales) sz_tensor_2 = self.restore_qzeros_tensor(layer.w2_qzeros, layer.w2_scales) layer.w13_scales = torch.nn.Parameter(sz_tensor_1,requires_grad=False) layer.w2_scales = torch.nn.Parameter(sz_tensor_2,requires_grad=False) layer.w13_qzeros = None layer.w2_qzeros = None torch.cuda.empty_cache() def apply( self, layer: torch.nn.Module, x: torch.Tensor, router_logits: torch.Tensor, top_k: int, renormalize: bool, use_grouped_topk: bool = False, topk_group: Optional[int] = None, num_expert_group: Optional[int] = None, global_num_experts: int = -1, expert_map: Optional[torch.Tensor] = None, custom_routing_function: Optional[Callable] = None, scoring_func: str = "softmax", e_score_correction_bias: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, activation: str = "silu", enable_eplb: bool = False, expert_load_view: Optional[torch.Tensor] = None, logical_to_physical_map: Optional[torch.Tensor] = None, logical_replica_count: Optional[torch.Tensor] = None, use_nn_moe: Optional[bool] = False, routed_scaling_factor: Optional[float] = None, use_fused_gate: Optional[bool] = False, ) -> torch.Tensor: if enable_eplb: raise NotImplementedError( "EPLB not supported for `MoeWNA16Method` yet.") assert activation == "silu", "Only SiLU activation is supported." topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, top_k=top_k, renormalize=renormalize, topk_group=topk_group, num_expert_group=num_expert_group, custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, routed_scaling_factor=routed_scaling_factor, use_fused_gate=use_fused_gate) weight_bits = self.quant_config.weight_bits has_zp = self.quant_config.has_zp if self.use_moe_lmslim: return fused_experts_w4a16( x, layer.w13_qweight, layer.w2_qweight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, use_int4_w4a16=True, global_num_experts=global_num_experts, expert_map=expert_map, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, block_shape=[0, layer.group_size]) if self.use_w4a16_cuda: m = topk_ids.shape[0] if m <= 512: return fused_experts_cuda(x, layer.w13_qweight, layer.w2_qweight, topk_weights, topk_ids, inplace=True, use_fp8_w8a8=False, use_int4_w4a16=weight_bits == 4, use_int8_w8a16=False, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, w1_zp=None, w2_zp=None, a1_scale=None, a2_scale=None, block_shape=[0, layer.group_size], expert_map=expert_map) return fused_experts( x, layer.w13_qweight, layer.w2_qweight, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True, activation=activation, use_int4_w4a16=weight_bits == 4, use_int8_w8a16=weight_bits == 8, global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, w1_scale=layer.w13_scales, w2_scale=layer.w2_scales, w1_zp=layer.w13_qzeros if has_zp else None, w2_zp=layer.w2_qzeros if has_zp else None, block_shape=[0, layer.group_size], use_nn_moe=False) @staticmethod def get_weight_loader(layer, weight_loader): def convert_awq_tensor(tensor, tensor_type): # convert awq qweight/qzeros to a standard format (assume int4) # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8) # qzeros: (k // group_size, n // pack_factor_bit32) -> # (n // pack_factor_bit8, k // group_size) # pack_factor_bit32 = 32 // weight_bits # pack_factor_bit8 = 8 // weight_bits # 0. suppose origin shape (a, b), dtype int32 # 1. convert to uint8, shape (a, b) -> (a, 4 * b) size0 = tensor.size(0) tensor = tensor.view(torch.uint8) # 2. unpack to uint4 (only when weight_bits == 4) # shape (a, 4 * b) -> (a, 4 * b, 2) shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device) tensor = (tensor[:, :, None] >> shifter) & 0xF # 3. change order, see # https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py # shape -> (a, 4 * b * pack_factor_bit8) reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7] tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order] tensor = tensor.view(size0, -1) # 4. transpose, shape -> (4 * b * pack_factor_bit8, a) tensor = tensor.T.contiguous() # 5. repack (only when weight_bits == 4) # qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8) # qzeros shape -> (4 * b, a) if tensor_type == "qweight": tensor = tensor[:, 1::2] * 16 + tensor[:, ::2] elif tensor_type == "qzeros": tensor = tensor[1::2, :] * 16 + tensor[::2, :] return tensor def convert_gptq_int4_qzeros(tensor): tensor = tensor.view(torch.uint8) shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device) tensor = (tensor[:, :, None] >> shifter) & 0xF tensor = tensor + 1 tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16 return tensor def moe_wna16_weight_loader(param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int, return_success: bool = False): if "g_idx" in weight_name: return if not layer.quant_config.has_zp and "qzeros" in weight_name: return device = get_tp_group().device tp_rank = get_tensor_model_parallel_rank() loaded_weight = loaded_weight.to(device) shard_size = layer.intermediate_size_per_partition # convert gptq and awq weight to a standard format if layer.quant_config.linear_quant_method == "awq": assert layer.quant_config.weight_bits == 4 if "weight" in weight_name: loaded_weight = convert_awq_tensor(loaded_weight, "qweight") elif "zeros" in weight_name: loaded_weight = convert_awq_tensor(loaded_weight, "qzeros") else: loaded_weight = loaded_weight.T elif layer.quant_config.linear_quant_method == "gptq": assert layer.quant_config.weight_bits in [4, 8] if "weight" in weight_name: loaded_weight = loaded_weight.T.contiguous().view( torch.uint8) elif "zeros" in weight_name: # add 1 to gptq qzeros to align with awq loaded_weight = loaded_weight.view(torch.uint8) if layer.quant_config.weight_bits == 4: loaded_weight = convert_gptq_int4_qzeros( loaded_weight).T else: loaded_weight = loaded_weight.T + 1 else: loaded_weight = loaded_weight.T # repeat the qzeros/scales to fit new group size if layer.group_size_div_factor > 1 and \ "qzeros" in weight_name or "scales" in weight_name: loaded_weight = loaded_weight.repeat_interleave( layer.group_size_div_factor, 1) if "w13_qzeros" in weight_name: tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[tp_rank] if shard_id == "w1": param.data[expert_id, :shard_size // 2] = tensor else: param.data[expert_id, shard_size // 2:] = tensor elif "w2_qzeros" in weight_name: param.data[expert_id] = loaded_weight.view( loaded_weight.size(0), layer.tp_size, -1)[:, tp_rank] else: weight_loader(param, loaded_weight, weight_name, shard_id, expert_id) return_success = True return return_success return moe_wna16_weight_loader