初始化项目,由ModelHub XC社区提供模型
Model: openbmb/BitCPM-CANN-0.5B-unquantized Source: Original Platform
This commit is contained in:
176
qat-convert.py
Normal file
176
qat-convert.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
import safetensors
|
||||
|
||||
class SteTernaryQuantizer(nn.Module):
|
||||
def __init__(self, group_size):
|
||||
super().__init__()
|
||||
self.group_size = group_size
|
||||
|
||||
def forward(self, x):
|
||||
org_w_shape = x.shape
|
||||
if self.group_size > 0:
|
||||
assert x.shape[-1] % self.group_size == 0
|
||||
x = x.reshape(-1, self.group_size)
|
||||
elif self.group_size == -1:
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
assert x.dim() == 2
|
||||
scales = 1.0 / (x.abs().mean(dim=1, keepdim=True).clamp_(min=1e-5))
|
||||
x_q = (torch.clamp(torch.round(x * scales),-1,1) / scales)
|
||||
assert torch.isnan(x_q).sum() == 0
|
||||
x = x.reshape(org_w_shape)
|
||||
x_q = x_q.reshape(org_w_shape)
|
||||
return x_q
|
||||
|
||||
class SteIntQuantizer(nn.Module):
|
||||
def __init__(self, bit, group_size):
|
||||
super().__init__()
|
||||
self.bit = bit
|
||||
self.group_size = group_size
|
||||
|
||||
def forward(self, x):
|
||||
org_w_shape = x.shape
|
||||
if self.group_size > 0:
|
||||
assert org_w_shape[-1] % self.group_size == 0
|
||||
x = x.reshape(-1, self.group_size)
|
||||
elif self.group_size == -1:
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
assert x.dim() == 2
|
||||
|
||||
abs_max_val = x.abs().amax(dim=1, keepdim=True)
|
||||
max_int = 2 ** (self.bit - 1) - 1
|
||||
min_int = - (2 ** (self.bit - 1))
|
||||
scales = abs_max_val.clamp(min=1e-5) / max_int
|
||||
|
||||
assert torch.isnan(scales).sum() == 0
|
||||
|
||||
x_q = (torch.clamp(torch.round(x / scales), min_int, max_int)) * scales
|
||||
|
||||
assert torch.isnan(x_q).sum() == 0
|
||||
|
||||
x = x.reshape(org_w_shape)
|
||||
x_q = x_q.reshape(org_w_shape)
|
||||
|
||||
return x_q
|
||||
|
||||
class SteInt2Quantizer(nn.Module):
|
||||
def __init__(self, group_size):
|
||||
super().__init__()
|
||||
self.group_size = group_size
|
||||
|
||||
def forward(self, x):
|
||||
org_w_shape = x.shape
|
||||
if self.group_size > 0:
|
||||
assert x.shape[-1] % self.group_size == 0
|
||||
x = x.reshape(-1, self.group_size)
|
||||
elif self.group_size == -1:
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
assert x.dim() == 2
|
||||
|
||||
scales = 1.0 / (x.abs().mean(dim=1, keepdim=True).clamp_(min=1e-5) * 1)
|
||||
x_q = (torch.clamp(torch.round(x * scales),-2,1) / scales)
|
||||
|
||||
assert torch.isnan(x_q).sum() == 0
|
||||
|
||||
x = x.reshape(org_w_shape)
|
||||
x_q = x_q.reshape(org_w_shape)
|
||||
|
||||
return x_q
|
||||
|
||||
def quantize_model_bin(input_bin_path, output_bin_path, quant_type="ternary", bit=2, group_size=128, device="cuda" if torch.cuda.is_available() else "cpu"):
|
||||
"""
|
||||
直接对PyTorch模型bin文件进行量化。
|
||||
|
||||
Args:
|
||||
input_bin_path: 输入模型bin文件路径
|
||||
output_bin_path: 输出量化后的模型bin文件路径
|
||||
quant_type: 量化类型 ("ternary" 或 "int")
|
||||
bit: 整数量化的位数 (仅在 quant_type="int" 时使用)
|
||||
group_size: 量化分组大小
|
||||
device: 运行设备
|
||||
"""
|
||||
print(f"加载模型文件: {input_bin_path}...")
|
||||
if input_bin_path.endswith(".bin"):
|
||||
state_dict = torch.load(input_bin_path, map_location=device)
|
||||
elif input_bin_path.endswith(".safetensors"):
|
||||
state_dict = safetensors.load_file(input_bin_path)
|
||||
elif os.path.isdir(input_bin_path) and os.path.exists(os.path.join(input_bin_path, "pytorch_model.bin")):
|
||||
state_dict = torch.load(os.path.join(input_bin_path, "pytorch_model.bin"), map_location=device)
|
||||
elif os.path.isdir(input_bin_path) and os.path.exists(os.path.join(input_bin_path, "model.safetensors")):
|
||||
state_dict = safetensors.load_file(os.path.join(input_bin_path, "model.safetensors"))
|
||||
else:
|
||||
raise ValueError(f"不支持的模型文件类型: {input_bin_path}")
|
||||
|
||||
print(f"应用 {quant_type} 量化...")
|
||||
if quant_type == "ternary":
|
||||
quantizer = SteTernaryQuantizer(group_size=group_size)
|
||||
elif quant_type == "int":
|
||||
quantizer = SteIntQuantizer(bit=bit, group_size=group_size)
|
||||
elif quant_type == "int2":
|
||||
quantizer = SteInt2Quantizer(group_size=group_size)
|
||||
else:
|
||||
raise ValueError(f"不支持的量化类型: {quant_type}")
|
||||
|
||||
# 统计需要量化的参数数量
|
||||
total_params = sum(1 for k, v in state_dict.items() if ("weight" in k and "layer" in k) or ("fc" in k))
|
||||
|
||||
# 应用量化
|
||||
with torch.no_grad():
|
||||
for name, param in tqdm(state_dict.items(), total=total_params, desc="量化中"):
|
||||
if (("weight" in name and "layer" in name and param.dim() == 2) or ("fc" in name and param.dim() == 2)):
|
||||
# 对权重进行量化
|
||||
original_weight = param.data.clone()
|
||||
quantized_weight = quantizer(original_weight)
|
||||
state_dict[name] = quantized_weight
|
||||
|
||||
# 打印前几个层的统计信息
|
||||
if total_params > 0:
|
||||
total_params -= 1
|
||||
if total_params > total_params - 5:
|
||||
print(f"层: {name}")
|
||||
print(f" 原始范围: {original_weight.min():.4f} 到 {original_weight.max():.4f}")
|
||||
print(f" 量化后范围: {quantized_weight.min():.4f} 到 {quantized_weight.max():.4f}")
|
||||
print(f" 均方误差: {((original_weight - quantized_weight)**2).mean():.8f}")
|
||||
|
||||
# 保存量化后的模型
|
||||
print(f"保存量化后的模型到: {output_bin_path}...")
|
||||
if output_bin_path.endswith(".bin"):
|
||||
torch.save(state_dict, output_bin_path)
|
||||
elif output_bin_path.endswith(".safetensors"):
|
||||
safetensors.save_file(state_dict, output_bin_path)
|
||||
else:
|
||||
os.makedirs(os.path.dirname(output_bin_path), exist_ok=True)
|
||||
output_bin_path = os.path.join(output_bin_path, "pytorch_model.bin")
|
||||
torch.save(state_dict, output_bin_path)
|
||||
print("完成!")
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="量化PyTorch模型bin文件")
|
||||
parser.add_argument("--input_bin", type=str, required=True, help="输入模型bin文件路径")
|
||||
parser.add_argument("--output", type=str, required=True, help="输出量化后的模型bin文件路径")
|
||||
parser.add_argument("--quant_type", type=str, default="ternary", choices=["ternary", "int", "int2"], help="量化类型")
|
||||
parser.add_argument("--bit", type=int, default=2, help="整数量化的位数")
|
||||
parser.add_argument("--group_size", type=int, default=-1, help="量化分组大小")
|
||||
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="运行设备")
|
||||
parser.add_argument("--config", type=str, default="", help="model config file")
|
||||
|
||||
args = parser.parse_args()
|
||||
os.makedirs(args.output, exist_ok=True)
|
||||
quantize_model_bin(
|
||||
input_bin_path=args.input_bin,
|
||||
output_bin_path=os.path.join(args.output, "pytorch_model.bin"),
|
||||
quant_type=args.quant_type,
|
||||
bit=args.bit,
|
||||
group_size=args.group_size,
|
||||
device=args.device
|
||||
)
|
||||
if args.config:
|
||||
os.system(f"cp {args.config}/* {args.output}")
|
||||
print(f"复制{args.config}文件到{args.output}")
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user