# SPDX-License-Identifier: Apache-2.0
import enum
from enum import Enum
from fractions import Fraction
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
    get_linear_quant_method)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           GroupQuantScaleParameter,
                                           PackedColumnParameter,
                                           PackedvLLMParameter,
                                           RowvLLMParameter)
from vllm.model_executor.layers.quantization.gptq import GPTQConfig as GPTQConfigOrig
from vllm.model_executor.layers.quantization.gptq import ExllamaState
from vllm_vacc.vllm.model_executor.models.vars import TRANSPOSE_GPTQ_WEIGHT
import math


def GPTQLinearMethod__init(self, quant_config: GPTQConfigOrig):
    self.quant_config = quant_config
    self.scale_k = 1
    self.split_num = 4


def int32_to_int4(s0, axis=-2):
    # Flatten first to shape [1, n].
    # Each int32 splits into 8 int4 values, each held in its own int32,
    # giving [8, n]:
    # x32 (int32) => 32 bits => 8 x 4-bit fields x4[8]
    # x32 bits 31-28 => x4[7]
    # x32 bits 27-24 => x4[6]
    # ...
    # x32 bits 3-0   => x4[0]
    # x32[index=0] => x4[7,6,5,4,3,2,1,0]
    # A 4-bit field maps to its real value NOT via two's complement but by
    # subtracting 8 (offset-binary with zero point 8):
    # 1111 => 15 => 15 - 8 = 7
    # 0110 => 6  => 6 - 8  = -2
    # Example: 0x6ACB372B is laid out little-endian in memory as bytes
    # 2B 37 CB 6A. Reading nibbles byte by byte gives 2,11,3,7,12,11,6,10;
    # subtracting 8 gives -6,3, -5,-1, 4,3, -2,2; swapping the two nibbles
    # within each byte yields the decoded sequence
    # int4: 3, -6, -1, -5, 3, 4, 2, -2
    s = s0.view(torch.uint32)
    pieces = []
    for i in range(8):
        mask = 15 << (i * 4)
        # torch.bitwise_and(mask, s) is unavailable for uint32 here, so the
        # masking goes through numpy instead.
        s2 = torch.from_numpy(np.bitwise_and(mask, s.numpy()))
        s3 = s2 / (2**(i * 4))
        s4 = s3.to(torch.int32)
        # Two's-complement decoding (s4[s4 > 7] -= 16) gives wrong results;
        # subtracting 8 directly is correct, yielding the range [-8, 7].
        s4 = s4 - 8
        pieces.append(s4.reshape(1, *s4.shape))
    pieces = torch.concatenate(pieces, 0)
    if axis == -2 or axis == 0:
        # [8, K//8, N] => [K//8, 8, N] => [K, N]
        pieces = pieces.transpose(-2,
                                  0).reshape(-1,
                                             pieces.shape[-1]).contiguous()
    else:
        # [8, N, K//8] => [N, K//8, 8] => [N, K]
        pieces = pieces.permute(1, 2, 0).reshape(pieces.shape[-2],
                                                 -1).contiguous()
    return pieces
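

# Minimal, self-contained sanity check for the nibble decoding described
# above (an illustrative sketch; _decode_nibbles is not part of this module's
# API or of vLLM). It reproduces the worked example from the comments.
def _decode_nibbles(x32: int) -> list:
    # Take the eight 4-bit fields of one int32, low nibble first, and map
    # each unsigned nibble n in [0, 15] to n - 8 in [-8, 7].
    return [((x32 >> (4 * i)) & 0xF) - 8 for i in range(8)]


# _decode_nibbles(0x6ACB372B) == [3, -6, -1, -5, 3, 4, 2, -2], matching the
# int4 sequence worked out in the comments above.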


def dequant_weight(qw, scales, group_size=128):
    N = qw.shape[1]
    int4_to_int32_axis = -2
    if TRANSPOSE_GPTQ_WEIGHT:
        N = qw.shape[0]
        int4_to_int32_axis = -1
    # int32 => 8 int4 => fp16
    qweight = int32_to_int4(qw, int4_to_int32_axis).to(torch.float16)
    if TRANSPOSE_GPTQ_WEIGHT:
        scales = scales.T.contiguous()
        qweight = qweight.T.contiguous()
    # Expand scales by group_size: every group_size values share one scale.
    scales = torch.concatenate([scales] * group_size, 1).reshape(-1, N)
    # print('qweight', qweight.shape, qweight.dtype)
    # print('scale', scales.shape, scales.dtype)
    dequant_weight = qweight * scales  # dequantize
    return dequant_weight
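

# Shape sketch for dequant_weight with TRANSPOSE_GPTQ_WEIGHT disabled
# (an illustrative check on random data, not part of the real load path;
# assumes a torch build where int32_to_int4's uint32 view is supported).
def _check_dequant_weight_shapes():
    K, N, group_size = 256, 8, 128
    qw = torch.randint(0, 2**31 - 1, (K // 8, N), dtype=torch.int32)
    scales = torch.rand(K // group_size, N, dtype=torch.float16)
    w = dequant_weight(qw, scales, group_size)
    # Eight int4 values per int32 row: [K // 8, N] -> [K, N].
    assert w.shape == (K, N) and w.dtype == torch.float16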


class GPTQConfig(QuantizationConfig):
    """Config class for GPTQ.

    Reference: https://arxiv.org/abs/2210.17323
    """

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.half, torch.bfloat16]


class GPTQLinearMethod(LinearMethodBase):

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        weight_loader = extra_weight_attrs.get("weight_loader")
        # if input_size_per_partition % self.quant_config.group_size != 0:
        #     raise ValueError(
        #         "The input size is not aligned with the quantized "
        #         "weight shape. This can be caused by too large "
        #         "tensor parallel size.")
        output_size_per_partition = sum(output_partition_sizes)
        if (output_size_per_partition % self.quant_config.pack_factor.numerator
                != 0):
            raise ValueError(
                "The output size is not aligned with the quantized "
                "weight shape. This can be caused by too large "
                "tensor parallel size.")

        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        exllama_state = ExllamaState.UNINITIALIZED
        scale_and_zero_size = input_size // group_size
        scale_and_zero_input_dim = None
        if (input_size != input_size_per_partition
                and self.quant_config.group_size != -1):
            # For act-order models, we cannot use Exllama for row parallel
            # layers.
            if self.quant_config.desc_act:
                exllama_state = ExllamaState.UNUSED
            else:
                # We need to partition qzeros and scales for the exllama
                # kernel.
                scale_and_zero_size = input_size_per_partition // group_size
                scale_and_zero_input_dim = 0

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition // self.quant_config.pack_factor,
                output_size_per_partition,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        g_idx = RowvLLMParameter(data=torch.tensor(
            [
                i // self.quant_config.group_size
                for i in range(input_size_per_partition)
            ],
            dtype=torch.int32,
        ),
                                 input_dim=0,
                                 weight_loader=weight_loader)

        qzeros_args = {
            "data":
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            "weight_loader":
            weight_loader
        }
        weight_scale_args = {
            "data":
            torch.empty(
                scale_and_zero_size,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            "weight_loader":
            weight_loader
        }

        if scale_and_zero_input_dim is None:
            scales = ChannelQuantScaleParameter(output_dim=1,
                                                **weight_scale_args)
            qzeros = PackedColumnParameter(
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)
        else:
            scales = GroupQuantScaleParameter(output_dim=1,
                                              input_dim=0,
                                              **weight_scale_args)
            qzeros = PackedvLLMParameter(
                input_dim=0,
                output_dim=1,
                packed_dim=1,
                packed_factor=self.quant_config.pack_factor,
                **qzeros_args)

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("g_idx", g_idx)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

        layer.exllama_state = exllama_state
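
    # Worked shape example for the parameters created above (hypothetical
    # sizes, assuming 4-bit GPTQ so pack_factor == 8, group_size == 128, and
    # a 4096 x 4096 weight partition):
    #   qweight: [4096 // 8, 4096]        = [512, 4096]  (int32, packed on K)
    #   g_idx:   [4096]                                  (group id per channel)
    #   qzeros:  [4096 // 128, 4096 // 8] = [32, 512]    (int32, packed on N)
    #   scales:  [4096 // 128, 4096]      = [32, 4096]   (params_dtype)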

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # For torch.compile.
        # self.quant_config.weight_bits == 4
        if TRANSPOSE_GPTQ_WEIGHT:
            layer.qzeros = Parameter(layer.qzeros.data.T.contiguous(),
                                     requires_grad=False)
            layer.qweight = Parameter(layer.qweight.data.T.contiguous(),
                                      requires_grad=False)
            layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
            layer.scales = Parameter(layer.scales.data.T.contiguous(),
                                     requires_grad=False)
        else:
            layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
            layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
            layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
            layer.scales = Parameter(layer.scales.data, requires_grad=False)
        # Exllama needs to shuffle the weight after it is loaded; the shuffle
        # is done here, before the first forward pass.
        if layer.exllama_state == ExllamaState.UNINITIALIZED:
            if self.quant_config.desc_act:
                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
            layer.exllama_state = ExllamaState.READY
            ops.gptq_shuffle(layer.qweight, layer.g_idx,
                             self.quant_config.weight_bits)
        else:
            layer.g_idx.data = torch.empty((0, ),
                                           dtype=torch.int,
                                           device=layer.g_idx.device)
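
    # Tiny illustration of the desc_act reordering above (hypothetical case):
    # with group_size = 2 and activation-ordered g_idx = [1, 0, 1, 0],
    # torch.argsort(g_idx) = [1, 3, 0, 2] is the permutation that gathers the
    # channels of group 0 ahead of group 1 before ops.gptq_shuffle runs.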

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Output is [M, N]; N lives on a different qweight axis depending on
        # whether the weight was transposed at load time.
        out_shape = x.shape[:-1] + (
            layer.qweight.shape[-2 if TRANSPOSE_GPTQ_WEIGHT else -1], )
        reshaped_x = x.reshape(-1, x.shape[-1])
        # print(f"~~~~ start dequant")
        # import time
        # start_quant_time = time.time()
        # weight = dequant_weight(layer.qweight.cpu(), layer.scales.cpu(),
        #                         self.quant_config.group_size // self.scale_k
        #                         ).to(layer.qweight.device)
        # end_quant_time = time.time()
        # print(f"~~~~ dequant time: {end_quant_time - start_quant_time}")
        # if torch.distributed.get_rank() == 0:
        #     print(f"~~~~ weight shape: {weight.shape}, "
        #           f"dtype: {weight.dtype}")
        # output = torch.matmul(reshaped_x, weight)
        # print("entering GPTQLinearMethod apply, reshaped_x shape:",
        #       reshaped_x.shape, "reshaped_x stride", reshaped_x.stride(),
        #       "input_tensor", x.shape, "qweight shape:",
        #       layer.qweight.shape, "scales shape:", layer.scales.shape)
        output = torch.vacc.w4a8_block_int4_matmul(
            reshaped_x,
            layer.qweight.transpose(-1, -2),
            layer.scales.transpose(-1, -2),
            [1, self.quant_config.group_size // self.scale_k],
        )
        # print("exiting GPTQLinearMethod apply, output shape:", output.shape)
        # end_gemm_time = time.time()
        # if torch.distributed.get_rank() == 0:
        #     print(f"~~~~ gemm time: {end_gemm_time - end_quant_time}")
        if bias is not None:
            output.add_(bias)
        return output.reshape(out_shape)
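
    # Argument sketch for the w4a8 matmul above (illustrative sizes, assuming
    # TRANSPOSE_GPTQ_WEIGHT is unset, K = 4096, N = 4096, group_size = 128,
    # scale_k = 1):
    #   reshaped_x:                [M, K]         activations
    #   qweight.transpose(-1, -2): [N, K // 8]    packed int4 weights
    #   scales.transpose(-1, -2):  [N, K // 128]  per-group scales
    #   block shape:               [1, group_size // scale_k] = [1, 128]
    #   output: [M, N], reshaped to x.shape[:-1] + (N,)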