176 lines
7.1 KiB
Python
176 lines
7.1 KiB
Python
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
from tqdm import tqdm
|
||
|
|
import os
|
||
|
|
import safetensors
|
||
|
|
|
||
|
|
class SteTernaryQuantizer(nn.Module):
    """Fake-quantize a tensor to ternary levels {-1, 0, +1} times a per-group scale.

    Each group is multiplied by the reciprocal of its mean absolute value
    (floored at 1e-5), rounded, clamped to [-1, 1] and divided back, so the
    result has the input's shape and dtype but takes at most three distinct
    values per group.

    NOTE: despite the ``Ste`` prefix, no straight-through estimator is applied:
    the rounded tensor is returned directly, so gradients flowing through
    ``torch.round`` are zero.

    Args:
        group_size: number of consecutive elements along the last dimension that
            share one scale. Any value <= 0 (conventionally -1) uses a single
            scale per row of the last dimension.
    """

    def __init__(self, group_size):
        super().__init__()
        self.group_size = group_size

    def forward(self, x):
        """Return the ternary fake-quantized version of ``x`` (same shape as ``x``).

        Raises:
            ValueError: if ``group_size > 0`` does not divide the last dimension,
                or if the result contains NaN (e.g. NaN/inf values in ``x``).
        """
        orig_shape = x.shape
        if self.group_size > 0:
            # Explicit check instead of `assert` so it survives `python -O`.
            if orig_shape[-1] % self.group_size != 0:
                raise ValueError(
                    f"last dimension ({orig_shape[-1]}) is not divisible by "
                    f"group_size ({self.group_size})"
                )
            groups = x.reshape(-1, self.group_size)
        else:
            # One scale per row of the last dimension.
            groups = x.reshape(-1, orig_shape[-1])

        # absmean scaling: an element equal to mean(|group|) maps to level 1.
        scales = 1.0 / groups.abs().mean(dim=1, keepdim=True).clamp(min=1e-5)
        x_q = torch.clamp(torch.round(groups * scales), -1, 1) / scales

        if torch.isnan(x_q).any():
            raise ValueError("ternary quantization produced NaN values")
        return x_q.reshape(orig_shape)
|
||
|
|
|
||
|
|
class SteIntQuantizer(nn.Module):
    """Fake-quantize a tensor to signed ``bit``-bit integers with per-group absmax scaling.

    Each group is scaled so that its largest magnitude (floored at 1e-5) maps to
    ``2**(bit-1) - 1``, then rounded, clamped to ``[-2**(bit-1), 2**(bit-1) - 1]``
    and scaled back. The result has the input's shape and dtype.

    NOTE: despite the ``Ste`` prefix, no straight-through estimator is applied:
    the rounded tensor is returned directly, so gradients flowing through
    ``torch.round`` are zero.

    Args:
        bit: integer bit width; must be >= 2 (``bit=1`` leaves no positive level
            and would divide by zero).
        group_size: number of consecutive elements along the last dimension that
            share one scale. Any value <= 0 (conventionally -1) uses a single
            scale per row of the last dimension.

    Raises:
        ValueError: if ``bit < 2``.
    """

    def __init__(self, bit, group_size):
        super().__init__()
        if bit < 2:
            raise ValueError(f"bit must be >= 2 for signed integer quantization, got {bit}")
        self.bit = bit
        self.group_size = group_size

    def forward(self, x):
        """Return the integer fake-quantized version of ``x`` (same shape as ``x``).

        Raises:
            ValueError: if ``group_size > 0`` does not divide the last dimension,
                or if the result contains NaN (e.g. NaN/inf values in ``x``).
        """
        orig_shape = x.shape
        if self.group_size > 0:
            # Explicit check instead of `assert` so it survives `python -O`.
            if orig_shape[-1] % self.group_size != 0:
                raise ValueError(
                    f"last dimension ({orig_shape[-1]}) is not divisible by "
                    f"group_size ({self.group_size})"
                )
            groups = x.reshape(-1, self.group_size)
        else:
            # One scale per row of the last dimension.
            groups = x.reshape(-1, orig_shape[-1])

        max_int = 2 ** (self.bit - 1) - 1
        min_int = -(2 ** (self.bit - 1))
        # Symmetric absmax scaling: |x| <= max_int * scale, so round(x / scale)
        # stays within [-max_int, max_int]; the extra negative code min_int is
        # effectively never produced.
        scales = groups.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / max_int
        x_q = torch.clamp(torch.round(groups / scales), min_int, max_int) * scales

        # A NaN scale would propagate into x_q, so one check covers both.
        if torch.isnan(x_q).any():
            raise ValueError("integer quantization produced NaN values")
        return x_q.reshape(orig_shape)
|
||
|
|
|
||
|
|
class SteInt2Quantizer(nn.Module):
    """Fake-quantize a tensor to the 2-bit grid {-2, -1, 0, +1} times a per-group scale.

    Each group is multiplied by the reciprocal of its mean absolute value
    (floored at 1e-5), rounded, clamped to [-2, 1] and divided back. The result
    has the input's shape and dtype.

    NOTE: despite the ``Ste`` prefix, no straight-through estimator is applied:
    the rounded tensor is returned directly, so gradients flowing through
    ``torch.round`` are zero.

    Args:
        group_size: number of consecutive elements along the last dimension that
            share one scale. Any value <= 0 (conventionally -1) uses a single
            scale per row of the last dimension.
    """

    def __init__(self, group_size):
        super().__init__()
        self.group_size = group_size

    def forward(self, x):
        """Return the 2-bit fake-quantized version of ``x`` (same shape as ``x``).

        Raises:
            ValueError: if ``group_size > 0`` does not divide the last dimension,
                or if the result contains NaN (e.g. NaN/inf values in ``x``).
        """
        orig_shape = x.shape
        if self.group_size > 0:
            # Explicit check instead of `assert` so it survives `python -O`.
            if orig_shape[-1] % self.group_size != 0:
                raise ValueError(
                    f"last dimension ({orig_shape[-1]}) is not divisible by "
                    f"group_size ({self.group_size})"
                )
            groups = x.reshape(-1, self.group_size)
        else:
            # One scale per row of the last dimension.
            groups = x.reshape(-1, orig_shape[-1])

        # absmean scaling as in the ternary quantizer, but with the asymmetric
        # two's-complement 2-bit range [-2, 1].
        scales = 1.0 / groups.abs().mean(dim=1, keepdim=True).clamp(min=1e-5)
        x_q = torch.clamp(torch.round(groups * scales), -2, 1) / scales

        if torch.isnan(x_q).any():
            raise ValueError("int2 quantization produced NaN values")
        return x_q.reshape(orig_shape)
|
||
|
|
|
||
|
|
def _load_safetensors(path, device):
    """Load a ``.safetensors`` file as a state dict placed on ``device``."""
    # The top-level `safetensors` package does not expose load_file/save_file;
    # the PyTorch helpers live in `safetensors.torch`.
    from safetensors.torch import load_file

    return load_file(path, device=str(device))


def _load_state_dict(input_bin_path, device):
    """Load a state dict from a checkpoint file or a model directory.

    Accepts a ``*.bin`` file, a ``*.safetensors`` file, or a directory that
    contains ``pytorch_model.bin`` (checked first) or ``model.safetensors``.

    Raises:
        ValueError: if ``input_bin_path`` matches none of these layouts.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from trusted sources.
    if input_bin_path.endswith(".bin"):
        return torch.load(input_bin_path, map_location=device)
    if input_bin_path.endswith(".safetensors"):
        return _load_safetensors(input_bin_path, device)
    if os.path.isdir(input_bin_path):
        bin_path = os.path.join(input_bin_path, "pytorch_model.bin")
        if os.path.exists(bin_path):
            return torch.load(bin_path, map_location=device)
        safetensors_path = os.path.join(input_bin_path, "model.safetensors")
        if os.path.exists(safetensors_path):
            return _load_safetensors(safetensors_path, device)
    raise ValueError(f"不支持的模型文件类型: {input_bin_path}")


def _build_quantizer(quant_type, bit, group_size):
    """Return the quantizer module for ``quant_type``; raise ValueError if unknown."""
    if quant_type == "ternary":
        return SteTernaryQuantizer(group_size=group_size)
    if quant_type == "int":
        return SteIntQuantizer(bit=bit, group_size=group_size)
    if quant_type == "int2":
        return SteInt2Quantizer(group_size=group_size)
    raise ValueError(f"不支持的量化类型: {quant_type}")


def _is_quantizable(name, tensor):
    """Return True for 2-D floating-point tensors selected by the name heuristic."""
    if not isinstance(tensor, torch.Tensor):
        return False
    if tensor.dim() != 2 or not tensor.is_floating_point():
        return False
    return ("weight" in name and "layer" in name) or "fc" in name


def _save_state_dict(state_dict, output_bin_path):
    """Write ``state_dict`` and return the path of the file actually written.

    ``*.bin`` is written with ``torch.save`` and ``*.safetensors`` with
    safetensors; any other path is treated as a directory (created if needed)
    that receives ``pytorch_model.bin``.
    """
    if output_bin_path.endswith((".bin", ".safetensors")):
        parent = os.path.dirname(output_bin_path)
        if parent:
            os.makedirs(parent, exist_ok=True)
        if output_bin_path.endswith(".safetensors"):
            from safetensors.torch import save_file

            save_file(state_dict, output_bin_path)
        else:
            torch.save(state_dict, output_bin_path)
        return output_bin_path

    # Directory output: create the directory itself, not just its parent,
    # before writing the checkpoint into it.
    os.makedirs(output_bin_path, exist_ok=True)
    target_path = os.path.join(output_bin_path, "pytorch_model.bin")
    torch.save(state_dict, target_path)
    return target_path


def quantize_model_bin(input_bin_path, output_bin_path, quant_type="ternary", bit=2, group_size=128, device="cuda" if torch.cuda.is_available() else "cpu"):
    """Fake-quantize the selected weights of a PyTorch checkpoint and save the result.

    A tensor is quantized when it is a 2-D floating-point tensor whose name
    contains both "weight" and "layer", or contains "fc". Each selected tensor
    is replaced by the quantizer's output, which keeps its shape and dtype; all
    other entries are written back unchanged. Statistics (value range and mean
    squared error) are printed for the first few quantized tensors.

    Args:
        input_bin_path: a ``*.bin`` or ``*.safetensors`` checkpoint, or a
            directory containing ``pytorch_model.bin`` / ``model.safetensors``.
        output_bin_path: destination; ``*.bin`` -> torch.save, ``*.safetensors``
            -> safetensors, anything else is treated as a directory that
            receives ``pytorch_model.bin``.
        quant_type: "ternary", "int" or "int2".
        bit: bit width, only used when ``quant_type == "int"``.
        group_size: number of elements sharing one scale along the last
            dimension; -1 means one scale per row.
        device: device the checkpoint is loaded onto (and therefore where the
            quantization arithmetic runs).

    Raises:
        ValueError: for an unsupported input layout or ``quant_type``.
    """
    max_reported_layers = 5

    # Validate quant_type before the (potentially slow) checkpoint load.
    quantizer = _build_quantizer(quant_type, bit, group_size)

    print(f"加载模型文件: {input_bin_path}...")
    state_dict = _load_state_dict(input_bin_path, device)

    print(f"应用 {quant_type} 量化...")
    # Select targets up front so the progress-bar total matches the work done
    # and the dict is not written to while it is being iterated.
    target_names = [name for name, tensor in state_dict.items() if _is_quantizable(name, tensor)]

    with torch.no_grad():
        for index, name in enumerate(tqdm(target_names, desc="量化中")):
            # The quantizers do not modify their input, so no clone is needed.
            original_weight = state_dict[name]
            quantized_weight = quantizer(original_weight)
            state_dict[name] = quantized_weight

            # Only report the first few layers to keep the log readable;
            # tqdm.write keeps the progress bar intact.
            if index < max_reported_layers:
                squared_error = (original_weight.float() - quantized_weight.float()).pow(2)
                tqdm.write(f"层: {name}")
                tqdm.write(f" 原始范围: {original_weight.min():.4f} 到 {original_weight.max():.4f}")
                tqdm.write(f" 量化后范围: {quantized_weight.min():.4f} 到 {quantized_weight.max():.4f}")
                tqdm.write(f" 均方误差: {squared_error.mean():.8f}")

    print(f"保存量化后的模型到: {output_bin_path}...")
    _save_state_dict(state_dict, output_bin_path)
    print("完成!")
|
||
|
|
|
||
|
|
def _copy_config_files(config_path, output_dir, skip_names=()):
    """Copy configuration files into ``output_dir``.

    ``config_path`` may be a single file or a directory. For a directory only
    its top-level, non-hidden regular files are copied (the set a shell
    ``cp <dir>/* <output>`` would copy); sub-directories are ignored. Files
    whose base name is in ``skip_names`` are never copied.
    """
    import shutil

    if os.path.isfile(config_path):
        sources = [config_path]
    else:
        sources = [
            os.path.join(config_path, entry)
            for entry in sorted(os.listdir(config_path))
            if not entry.startswith(".")
        ]

    for src in sources:
        name = os.path.basename(src)
        if name in skip_names or not os.path.isfile(src):
            continue
        shutil.copy(src, os.path.join(output_dir, name))


def main():
    """Command-line entry point: quantize a checkpoint and optionally copy config files.

    The quantized weights are always written to ``<output>/pytorch_model.bin``.
    With ``--config``, the given file (or the top-level files of the given
    directory) is copied into ``<output>`` afterwards, except for a
    ``pytorch_model.bin`` that would overwrite the quantized weights.
    """
    import argparse

    parser = argparse.ArgumentParser(description="量化PyTorch模型bin文件")
    parser.add_argument("--input_bin", type=str, required=True, help="输入模型bin文件路径")
    parser.add_argument("--output", type=str, required=True, help="输出量化后的模型bin文件路径")
    parser.add_argument("--quant_type", type=str, default="ternary", choices=["ternary", "int", "int2"], help="量化类型")
    parser.add_argument("--bit", type=int, default=2, help="整数量化的位数")
    parser.add_argument("--group_size", type=int, default=-1, help="量化分组大小")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="运行设备")
    parser.add_argument("--config", type=str, default="", help="model config file")
    args = parser.parse_args()

    # Report a bad --config before the slow quantization, not after it.
    if args.config and not os.path.exists(args.config):
        parser.error(f"--config path does not exist: {args.config}")

    os.makedirs(args.output, exist_ok=True)
    weights_name = "pytorch_model.bin"
    quantize_model_bin(
        input_bin_path=args.input_bin,
        output_bin_path=os.path.join(args.output, weights_name),
        quant_type=args.quant_type,
        bit=args.bit,
        group_size=args.group_size,
        device=args.device,
    )

    if args.config:
        # shutil instead of `os.system("cp ...")`: handles spaces in paths,
        # works off POSIX, and involves no shell. The quantized weights file is
        # skipped so an original checkpoint in the config directory cannot
        # overwrite it. NOTE(review): other weight files there (e.g.
        # model.safetensors) are still copied — confirm that is intended.
        _copy_config_files(args.config, args.output, skip_names=(weights_name,))
        print(f"复制{args.config}文件到{args.output}")


if __name__ == "__main__":
    main()
|