# SPDX-License-Identifier: Apache-2.0
import enum
import math
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.gptq import (
    ExllamaState, GPTQConfig as GPTQConfigOrig)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_linear_quant_method)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedColumnParameter,
                                           PackedvLLMParameter,
                                           RowvLLMParameter)
from vllm_vacc.vllm.model_executor.models.vars import TRANSPOSE_GPTQ_WEIGHT


def GPTQLinearMethod__init(self, quant_config: GPTQConfigOrig):
    self.quant_config = quant_config
    self.scale_k = 1
    self.split_num = 4


def int32_to_int4(s0, axis=-2):
    # Flatten to shape [1, n] first.
    # Each int32 is split into 8 int4 values (each held in its own int32),
    # giving [8, n]:
    #   x32 (int32) => 32 bits => 8 x 4-bit nibbles x4[0..7]
    #   x32 bits 31-28 => x4[7]
    #   x32 bits 27-24 => x4[6]
    #   ...
    #   x32 bits 3-0   => x4[0]
    #   x32[index=0]   => x4[7,6,5,4,3,2,1,0]
    # Mapping a 4-bit nibble to its real value is NOT two's complement;
    # instead, subtract 8 (a fixed zero-point offset):
    #   1111 => 15 => 7    (15 - 8 = 7)
    #   0101 =>  6 => -2   ( 6 - 8 = -2)
    # Example: 0x6ACB372B (laid out in memory as 2B 37 CB 6A) => nibbles
    # B273BCA6 => (-8) => int4: 3, -6, -1, -5, 3, 4, 2, -2.
    # Memory layout is little-endian:
    #   int32: 2B 37 CB 6A => 2,11,3,7,12,11,6,10 => (-8) =>
    #   -6,3, -5,-1, 4,3, -2,2; swapping the two nibbles within each byte
    #   yields 3, -6, -1, -5, 3, 4, 2, -2.
    #   int4: 3, -6, -1, -5, 3, 4, 2, -2
    s = s0.view(torch.uint32)
    all = []
    for i in range(8):
        x = 15 << (i * 4)
        # s2 = torch.bitwise_and(x, s)
        s2 = torch.from_numpy(np.bitwise_and(x, s.numpy()))
        s3 = s2 / (2**(i * 4))
        s4 = s3.to(torch.int32)
        # Two's complement gives the wrong result here:
        # s4[s4 > 7] = s4[s4 > 7] - 16
        # Subtracting 8 directly is correct; value range: -8..7.
        s4 = s4 - 8
        all.append(s4.reshape(1, *s4.shape))
    all = torch.concatenate(all, 0)
    if axis == -2 or axis == 0:
        # 8,K//8,N => K//8,8,N => K,N
        all = all.transpose(-2, 0).reshape(-1, all.shape[-1]).contiguous()
    else:
        # 8,N,K//8 => N,K//8,8 => N,K
        all = all.permute(1, 2, 0).reshape(all.shape[-2], -1).contiguous()
    return all


def dequant_weight(qw, scales, group_size=128):
    N = qw.shape[1]
    int4_to_int32_axis = -2
    if TRANSPOSE_GPTQ_WEIGHT:
        N = qw.shape[0]
        int4_to_int32_axis = -1
    # int32 => 8 int4 => fp16
    qweight = int32_to_int4(qw, int4_to_int32_axis).to(torch.float16)
    if TRANSPOSE_GPTQ_WEIGHT:
        scales = scales.T.contiguous()
        qweight = qweight.T.contiguous()
    # Expand scales by group_size: every group_size values share one scale.
    scales = torch.concatenate([scales] * group_size, 1).reshape(-1, N)
    # print('qweight', qweight.shape, qweight.dtype)
    # print('scale', scales.shape, scales.dtype)
    dequant_weight = qweight * scales  # dequantize
    return dequant_weight
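
# Illustrative self-check (not part of the original module; the helper name
# _check_int32_to_int4_example is hypothetical): verifies int32_to_int4
# against the worked example in the comments above. Assumes a PyTorch build
# where the uint32 view/numpy round-trip used by int32_to_int4 is supported.
def _check_int32_to_int4_example():
    # One packed int32 holding eight int4 values, shape [K//8=1, N=1].
    packed = torch.tensor([[0x6ACB372B]], dtype=torch.int32)
    # Unpacks to shape [K=8, N=1], lowest nibble first.
    unpacked = int32_to_int4(packed)
    expected = torch.tensor([3, -6, -1, -5, 3, 4, 2, -2], dtype=torch.int32)
    assert torch.equal(unpacked.flatten(), expected)
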
class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]


class GPTQLinearMethod(LinearMethodBase):

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        weight_loader = extra_weight_attrs.get("weight_loader")
        # if input_size_per_partition % self.quant_config.group_size != 0:
        #     raise ValueError(
        #         "The input size is not aligned with the quantized "
        #         "weight shape. This can be caused by too large "
        #         "tensor parallel size.")

        output_size_per_partition = sum(output_partition_sizes)
        if (output_size_per_partition %
                self.quant_config.pack_factor.numerator != 0):
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        exllama_state = ExllamaState.UNINITIALIZED
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if (input_size != input_size_per_partition
                and self.quant_config.group_size != -1):
            # For act-order models, we cannot use Exllama for row parallel
            # layers.
            if self.quant_config.desc_act:
                exllama_state = ExllamaState.UNUSED
            else:
                # We need to partition qzeros and scales for the exllama
                # kernel.
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        g_idx = RowvLLMParameter(data=torch.tensor(
            [
                i // self.quant_config.group_size
                for i in range(input_size_per_partition)
            ],
            dtype=torch.int32,
        ),
                                 input_dim=0,
                                 weight_loader=weight_loader)

        qzeros_args = {
            "data":
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader":
            weight_loader
        }
        weight_scale_args = {
            "data":
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }
        if scale_and_zero_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)
        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

        layer.exllama_state = exllama_state

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # For torch.compile.
        # self.quant_config.weight_bits == 4
        if TRANSPOSE_GPTQ_WEIGHT:
            layer.qzeros = Parameter(layer.qzeros.data.T.contiguous(),
                                     requires_grad=False)
            layer.qweight = Parameter(layer.qweight.data.T.contiguous(),
                                      requires_grad=False)
            layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
            layer.scales = Parameter(layer.scales.data.T.contiguous(),
                                     requires_grad=False)
        else:
            layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
            layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
            layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
            layer.scales = Parameter(layer.scales.data, requires_grad=False)

        # Exllama needs to shuffle the weight after the weight is loaded;
        # here we do the shuffle on the first forward pass.
        if layer.exllama_state == ExllamaState.UNINITIALIZED:
            if self.quant_config.desc_act:
                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
                layer.exllama_state = ExllamaState.READY
                ops.gptq_shuffle(layer.qweight, layer.g_idx,
                                 self.quant_config.weight_bits)
            else:
                layer.g_idx.data = torch.empty((0, ),
                                               dtype=torch.int,
                                               device=layer.g_idx.device)
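
    # Layout note (derived from create_weights above): with 4-bit GPTQ and
    # pack_factor 8, qweight is [K // 8, N] int32, qzeros is
    # [K // group_size, N // 8] int32, scales is [K // group_size, N] in
    # params_dtype, and g_idx is [K] int32, where K is
    # input_size_per_partition and N is output_size_per_partition. When
    # TRANSPOSE_GPTQ_WEIGHT is set, process_weights_after_loading stores
    # qweight, qzeros, and scales transposed, which is why apply() below
    # transposes them back before calling the matmul kernel.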

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Output features N live on dim -2 of qweight when the weight is
        # stored transposed, else on dim -1; the result is [M, N].
        out_shape = x.shape[:-1] + (
            layer.qweight.shape[-2 if TRANSPOSE_GPTQ_WEIGHT else -1], )
        reshaped_x = x.reshape(-1, x.shape[-1])  # [M, K]
        # print(f"~~~~ start dequant")
        # import time
        # start_quant_time = time.time()
        # weight = dequant_weight(
        #     layer.qweight.cpu(), layer.scales.cpu(),
        #     self.quant_config.group_size // self.scale_k).to(
        #         layer.qweight.device)
        # end_quant_time = time.time()
        # print(f"~~~~ dequant time: {end_quant_time - start_quant_time}")
        # if torch.distributed.get_rank() == 0:
        #     print(f"~~~~ weight shape: {weight.shape}, "
        #           f"dtype: {weight.dtype}")
        # output = torch.matmul(reshaped_x, weight)
        # print("entering GPTQLinearMethod apply, reshaped_x shape:",
        #       reshaped_x.shape, "reshaped_x stride", reshaped_x.stride(),
        #       "input_tensor", x.shape, "qweight shape:",
        #       layer.qweight.shape, "scales shape:", layer.scales.shape)
        output = torch.vacc.w4a8_block_int4_matmul(
            reshaped_x,
            layer.qweight.transpose(-1, -2),
            layer.scales.transpose(-1, -2),
            [1, self.quant_config.group_size // self.scale_k],
        )
        # print("exiting GPTQLinearMethod apply, output shape:", output.shape)
        # end_gemm_time = time.time()
        # if torch.distributed.get_rank() == 0:
        #     print(f"~~~~ gemm time: {end_gemm_time - end_quant_time}")
        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)
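
# A minimal reference sketch (not wired in; the name _reference_apply is
# hypothetical): mirrors the commented-out fallback inside apply() above,
# dequantizing on the CPU with dequant_weight and using a plain matmul.
# Useful for sanity-checking w4a8_block_int4_matmul outputs, assuming
# scale_k == 1 and a group-quantized (group_size != -1) checkpoint.
def _reference_apply(layer, x, group_size, scale_k=1):
    reshaped_x = x.reshape(-1, x.shape[-1])  # [M, K]
    # dequant_weight returns the full fp16 weight of shape [K, N],
    # regardless of whether TRANSPOSE_GPTQ_WEIGHT is set.
    weight = dequant_weight(layer.qweight.cpu(), layer.scales.cpu(),
                            group_size // scale_k).to(x.device)
    output = torch.matmul(reshaped_x.to(weight.dtype), weight)
    return output.reshape(x.shape[:-1] + (weight.shape[-1], ))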