初始化项目,由ModelHub XC社区提供模型
Model: BoscoTheDog/bitnet_b1_58-large_q8_0_gguf Source: Original Platform
This commit is contained in:
48
utils_quant.py
Normal file
48
utils_quant.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
def weight_quant(weight, num_bits=1):
    """Ternary-quantize a weight tensor (BitNet b1.58 style).

    Scales the weight by the inverse of its mean absolute value, rounds
    to the nearest integer, clamps to {-1, 0, +1}, then rescales back.
    The straight-through-estimator trick is applied by the caller.

    Args:
        weight: weight tensor of any floating dtype.
        num_bits: kept for interface symmetry with activation_quant;
            the quantization grid here is fixed to {-1, 0, +1}.

    Returns:
        Quantized tensor with the same shape and dtype as ``weight``.
    """
    orig_dtype = weight.dtype
    w = weight.float()
    # Inverse mean-absolute-value scale; clamp avoids divide-by-zero
    # for an all-zero weight tensor.
    scale = 1 / w.abs().mean().clamp(min=1e-5)
    quantized = (w * scale).round().clamp(-1, 1) / scale
    return quantized.type(orig_dtype)
|
||||
|
||||
|
||||
def activation_quant(x, num_bits=8):
    """Per-token absmax quantization of activations to signed integers.

    Each slice along the last dimension is scaled so its absolute
    maximum maps to the top of the symmetric ``num_bits`` integer range
    (e.g. [-128, 127] for 8 bits), rounded, clamped, and rescaled back.

    Args:
        x: activation tensor of any floating dtype.
        num_bits: bit width of the signed integer grid (default 8).

    Returns:
        Quantized tensor with the same shape and dtype as ``x``.
    """
    orig_dtype = x.dtype
    x_fp32 = x.float()
    # Symmetric signed range: [-2^(b-1), 2^(b-1) - 1].
    q_min = -(2 ** (num_bits - 1))
    q_max = 2 ** (num_bits - 1) - 1
    # Per-token scale from the last-dim absolute maximum; the clamp
    # guards against division by zero on all-zero rows.
    scale = q_max / x_fp32.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    quantized = (x_fp32 * scale).round().clamp(q_min, q_max) / scale
    return quantized.type(orig_dtype)
|
||||
|
||||
|
||||
class BitLinear(nn.Linear):
    """nn.Linear with quantized weights and activations (BitNet).

    Weights are quantized with ``weight_quant`` and activations with
    ``activation_quant``; both use the straight-through estimator so
    gradients flow through the quantization as identity.

    RMSNorm is placed outside BitLinear.
    """

    def __init__(self,
                 *kargs,
                 weight_bits=1,
                 input_bits=8,
                 **kwargs
                 ):
        """Construct the underlying nn.Linear and record bit widths.

        Args:
            *kargs, **kwargs: forwarded to nn.Linear (in_features,
                out_features, bias, ...).
            weight_bits: bit width passed to weight_quant (default 1).
            input_bits: bit width passed to activation_quant (default 8).
        """
        super(BitLinear, self).__init__(*kargs, **kwargs)
        self.weight_bits = weight_bits
        self.input_bits = input_bits

    def forward(self, input):
        # Straight-through estimator: the forward value is the quantized
        # tensor, but the (quant - real).detach() term carries no gradient,
        # so backward sees the identity function.
        quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
        quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()

        out = nn.functional.linear(quant_input, quant_weight)
        # Fixed idiom: `not self.bias is None` -> `self.bias is not None`
        # (PEP 8 / E714); behavior is unchanged.
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)

        return out
|
||||
Reference in New Issue
Block a user