import torch
from torch import nn


def weight_quant(weight, num_bits=1):
    # Ternary weight quantization: scale by the mean absolute value, round
    # to the nearest of {-1, 0, +1}, then rescale back (fake quantization).
    # num_bits is accepted for API symmetry with activation_quant but is
    # unused: the quantization levels are hard-coded.
    dtype = weight.dtype
    weight = weight.float()
    s = 1 / weight.abs().mean().clamp(min=1e-5)
    result = (weight * s).round().clamp(-1, 1) / s
    return result.type(dtype)
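
# A quick worked example of weight_quant (illustrative values, computed by
# hand; not part of the original file):
#
#   >>> weight_quant(torch.tensor([0.8, -0.05, 0.3]))
#   tensor([0.3833, 0.0000, 0.3833])
#
# Here s = 1 / mean(|w|) ≈ 2.61, so 0.8 and 0.3 round to +1 while -0.05
# rounds to 0, and dividing by s maps the ternary codes back to the
# original weight scale.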


def activation_quant(x, num_bits=8):
    # Per-token absmax quantization: each row (last dim) is scaled so its
    # largest magnitude maps to Qp, rounded to signed num_bits integers
    # (default 8-bit: [-128, 127]), then rescaled back (fake quantization).
    dtype = x.dtype
    x = x.float()
    Qn = -(2 ** (num_bits - 1))
    Qp = 2 ** (num_bits - 1) - 1
    s = Qp / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    result = (x * s).round().clamp(Qn, Qp) / s
    return result.type(dtype)
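
# A quick worked example of activation_quant (illustrative values, computed
# by hand; not part of the original file):
#
#   >>> activation_quant(torch.tensor([[0.4, -1.0, 0.25]]))
#   tensor([[ 0.4016, -1.0000,  0.2520]])
#
# The row's absmax is 1.0, so s = 127; 0.25 * 127 = 31.75 rounds to 32, and
# 32 / 127 ≈ 0.2520 shows the small error that 8-bit rounding introduces.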


class BitLinear(nn.Linear):
    """nn.Linear with quantized weights and activations.

    RMSNorm is placed outside BitLinear.
    """

    def __init__(self, *args, weight_bits=1, input_bits=8, **kwargs):
        super().__init__(*args, **kwargs)
        self.weight_bits = weight_bits
        self.input_bits = input_bits

    def forward(self, input):
        # Straight-through estimator: the forward pass uses the quantized
        # tensors, while detach() makes quantization act as the identity in
        # the backward pass, so gradients update the latent full-precision
        # weights.
        quant_input = input + (activation_quant(input, self.input_bits) - input).detach()
        quant_weight = self.weight + (weight_quant(self.weight, self.weight_bits) - self.weight).detach()

        out = nn.functional.linear(quant_input, quant_weight)
        if self.bias is not None:
            out += self.bias.view(1, -1).expand_as(out)
        return out
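

# Minimal usage sketch (assumed, not part of the original file): BitLinear
# is constructed like nn.Linear, with the quantization bit-widths passed as
# extra keyword arguments. Sizes below are arbitrary.
if __name__ == "__main__":
    layer = BitLinear(16, 4, bias=True, weight_bits=1, input_bits=8)
    x = torch.randn(2, 16)
    y = layer(x)
    print(y.shape)  # torch.Size([2, 4])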