add pkgs
examples/bloom/smoothquant.py · 141 lines · Normal file
@@ -0,0 +1,141 @@
'''
Utilities for SmoothQuant models
'''

import functools
from collections import defaultdict

import torch
import torch.nn as nn
from tqdm import tqdm
from transformers.pytorch_utils import Conv1D


@torch.no_grad()
def apply_smoothing(scales,
                    gemm_weights,
                    layernorm_weights=None,
                    layernorm_bias=None,
                    dtype=torch.float32,
                    layernorm_1p=False):
    if not isinstance(gemm_weights, list):
        gemm_weights = [gemm_weights]

    if layernorm_weights is not None:
        assert layernorm_weights.numel() == scales.numel()
        # Fold the smoothing scales into the preceding layernorm, in place.
        # In-place ops keep the original dtype, so no extra cast is needed
        # (the original chained a discarded `.to(dtype)` here, a no-op).
        layernorm_weights.div_(scales)
    if layernorm_bias is not None:
        assert layernorm_bias.numel() == scales.numel()
        layernorm_bias.div_(scales)
    if layernorm_1p:
        # "1+weight" layernorm variants store w - 1, so (1 + w) / s becomes
        # w / s + (1 / s - 1) after the division above.
        layernorm_weights += (1 / scales) - 1

    for gemm in gemm_weights:
        # Scale the GEMM input channels to compensate the layernorm division.
        gemm.mul_(scales.view(1, -1))
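

# Editor's sketch (not part of the original commit): a minimal, self-contained
# check of apply_smoothing on hypothetical shapes (hidden size 4, 8 output
# channels, weights stored transposed as [out, in]).
def _demo_apply_smoothing():
    ln_w, ln_b = torch.ones(4), torch.zeros(4)
    w = torch.ones(8, 4)
    s = torch.full((4, ), 2.0)
    apply_smoothing(s, w, ln_w, ln_b)
    # Layernorm params are divided by s, GEMM input channels multiplied by s.
    assert torch.allclose(ln_w, torch.full((4, ), 0.5))
    assert torch.allclose(w, torch.full((8, 4), 2.0))

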
@torch.no_grad()
def smooth_gemm(gemm_weights,
                act_scales,
                layernorm_weights=None,
                layernorm_bias=None,
                alpha=0.5,
                weight_scales=None):
    if not isinstance(gemm_weights, list):
        gemm_weights = [gemm_weights]
    orig_dtype = gemm_weights[0].dtype

    for gemm in gemm_weights:
        # gemm_weights are expected to be transposed, i.e. [out, in].
        assert gemm.shape[1] == act_scales.numel()

    if weight_scales is None:
        weight_scales = torch.cat(
            [gemm.abs().max(dim=0, keepdim=True)[0] for gemm in gemm_weights],
            dim=0)
        weight_scales = weight_scales.max(dim=0)[0]
    # Assign the result: .to()/.clamp() are out-of-place, so dropping the
    # return value (as the original did) would leave weight_scales unclamped.
    weight_scales = weight_scales.to(float).clamp(min=1e-5)
    # SmoothQuant migration strength: s = act_scale^alpha / weight_scale^(1-alpha).
    scales = (act_scales.to(gemm_weights[0].device).to(float).pow(alpha) /
              weight_scales.pow(1 - alpha)).clamp(min=1e-5)

    apply_smoothing(scales, gemm_weights, layernorm_weights, layernorm_bias,
                    orig_dtype)

    return scales
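

# Editor's sketch (not part of the original commit): smooth a single
# transposed GEMM weight against hypothetical per-channel activation scales.
def _demo_smooth_gemm():
    w = torch.randn(8, 4)
    act = torch.rand(4) + 0.1
    s = smooth_gemm(w, act, alpha=0.5)
    # One scale per input channel; w has been rescaled in place.
    assert s.shape == (4, )

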
@torch.no_grad()
def smooth_ln_fcs(ln, fcs, act_scales, alpha=0.5):
    if not isinstance(fcs, list):
        fcs = [fcs]
    for fc in fcs:
        assert isinstance(fc, nn.Linear)
        # Guard the layernorm check so ln=None (handled below) does not crash.
        if ln is not None:
            assert ln.weight.numel() == fc.in_features
        assert fc.in_features == act_scales.numel()

    device, dtype = fcs[0].weight.device, fcs[0].weight.dtype
    act_scales = act_scales.to(device=device, dtype=dtype)
    # Per-input-channel weight magnitudes, taken across all fused linears.
    weight_scales = torch.cat(
        [fc.weight.abs().max(dim=0, keepdim=True)[0] for fc in fcs], dim=0)
    weight_scales = weight_scales.max(dim=0)[0].clamp(min=1e-5)

    scales = (act_scales.pow(alpha) /
              weight_scales.pow(1 - alpha)).clamp(min=1e-5).to(device).to(dtype)

    if ln is not None:
        ln.weight.div_(scales)
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))
    return scales
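

# Editor's sketch (not part of the original commit): smooth a LayerNorm
# followed by two fused projections, as in an attention QKV block.
def _demo_smooth_ln_fcs():
    ln = nn.LayerNorm(4)
    q, k = nn.Linear(4, 8), nn.Linear(4, 8)
    act = torch.rand(4) + 0.1
    s = smooth_ln_fcs(ln, [q, k], act, alpha=0.5)
    assert s.shape == (4, )

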
@torch.no_grad()
def capture_activation_range(model,
                             tokenizer,
                             dataset,
                             num_samples=512,
                             seq_len=512):
    model.eval()
    device = next(model.parameters()).device
    act_scales = defaultdict(lambda: {"x": None, "y": None, "w": None})

    def stat_tensor(name, tensor, act_scales, key):
        hidden_dim = tensor.shape[-1]
        tensor = tensor.view(-1, hidden_dim).abs().detach()
        coming_max = torch.max(tensor, dim=0)[0].float()

        if act_scales[name][key] is None:
            act_scales[name][key] = coming_max
        else:
            # Keep a running element-wise maximum across calibration batches.
            act_scales[name][key] = torch.max(act_scales[name][key],
                                              coming_max)

    def stat_input_hook(m, x, y, name):
        if isinstance(x, tuple):
            x = x[0]
        stat_tensor(name, x, act_scales, "x")
        stat_tensor(name, y, act_scales, "y")

        if act_scales[name]["w"] is None:
            act_scales[name]["w"] = m.weight.abs().clip(1e-8,
                                                        None).max(dim=1)[0]

    hooks = []
    for name, m in model.named_modules():
        if isinstance(m, nn.Linear) or isinstance(m, Conv1D):
            hooks.append(
                m.register_forward_hook(
                    functools.partial(stat_input_hook, name=name)))

    for i in tqdm(range(num_samples), desc="calibrating model"):
        input_ids = tokenizer(dataset[i]["text"],
                              return_tensors="pt",
                              max_length=seq_len,
                              truncation=True).input_ids.to(device)
        model(input_ids)

    for h in hooks:
        h.remove()

    return act_scales
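

# Editor's sketch (not part of the original commit): calibrate a hypothetical
# Hugging Face causal LM on a text dataset. The model and dataset names below
# are illustrative placeholders, not prescribed by this file.
def _demo_capture_activation_range():
    from datasets import load_dataset
    from transformers import AutoModelForCausalLM, AutoTokenizer
    tok = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
    model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
    data = load_dataset("lambada", split="validation")
    scales = capture_activation_range(model, tok, data, num_samples=16)
    # scales maps each Linear module name to per-channel "x"/"y"/"w" maxima.
    return scales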