"""Loader NANO-v3.1 UNIVERSAL (Inference Only)"""
|
|
import os
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
|
|
|
|
|
class TrueQuantLinear(nn.Module):
    """Inference-only linear layer holding two quantized row groups:
    "protected" rows (pq/ps/pi) and "degraded" rows (dq/ds/di), each stored
    as int8 codes plus per-row fp16 scales and the output-row indices they
    reconstruct."""

    def __init__(self, pq, ps, pi, dq, ds, di, out_features, bias=None, bits=8, device="cuda:0"):
        super().__init__()
        self.out_features = out_features
        self.bits = int(bits)
        # Codes, scales, and index tensors are registered as buffers so they
        # move with the module but are never treated as trainable parameters.
        self.register_buffer("pq", pq.to(device=device, dtype=torch.int8))
        self.register_buffer("ps", ps.to(device=device, dtype=torch.float16))
        self.register_buffer("pi", pi.to(device=device, dtype=torch.long))
        self.register_buffer("dq", dq.to(device=device, dtype=torch.int8))
        self.register_buffer("ds", ds.to(device=device, dtype=torch.float16))
        self.register_buffer("di", di.to(device=device, dtype=torch.long))
        if bias is not None:
            self.register_buffer("bias", bias.to(device=device, dtype=torch.float16))
        else:
            self.bias = None

    def forward(self, x):
        d, dt = x.device, x.dtype
        # Compute in fp16, flattening all leading dimensions into one batch.
        f = x.to(torch.float16).reshape(-1, x.shape[-1])
        o = torch.zeros(f.shape[0], self.out_features, dtype=torch.float16, device=d)
        # Each group dequantizes its weight rows (codes * scales), multiplies,
        # and scatters the result into its output columns by index.
        if self.pq.shape[0] > 0:
            o.index_copy_(-1, self.pi.to(d), f @ (self.pq.to(d, torch.float16) * self.ps.to(d).unsqueeze(1)).t())
        if self.dq.shape[0] > 0:
            o.index_copy_(-1, self.di.to(d), f @ (self.dq.to(d, torch.float16) * self.ds.to(d).unsqueeze(1)).t())
        if self.bias is not None:
            o = o + self.bias.to(d)
        # Restore the original leading shape and input dtype.
        return o.reshape(*x.shape[:-1], self.out_features).to(dt)
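

# --- Hedged sanity check (not part of the original loader) ------------------
# A minimal sketch of the layout TrueQuantLinear expects: each group's rows
# reconstruct as codes * scale and scatter into the output rows named by its
# index tensor. All names, shapes, and the index split below are illustrative;
# the fp16 matmul needs a recent PyTorch on CPU, or any CUDA device.
def _selfcheck_true_quant_linear():
    torch.manual_seed(0)
    in_f, out_f = 8, 6
    w = torch.randn(out_f, in_f)
    pi = torch.tensor([0, 2, 4])  # "protected" output rows (8-bit range here)
    di = torch.tensor([1, 3, 5])  # "degraded" output rows (4-bit range here)

    def quantize(rows, qmax):
        # Symmetric per-row quantization: scale by the row max, round, clamp.
        scale = rows.abs().amax(dim=1) / qmax
        q = (rows / scale.unsqueeze(1)).round().clamp(-qmax, qmax)
        return q.to(torch.int8), scale

    pq, ps = quantize(w[pi], 127)
    dq, ds = quantize(w[di], 7)
    layer = TrueQuantLinear(pq, ps, pi, dq, ds, di, out_f, bits=4, device="cpu")

    # Reference: dequantize both groups into one dense fp16 weight matrix.
    w_hat = torch.zeros(out_f, in_f, dtype=torch.float16)
    w_hat[pi] = pq.to(torch.float16) * ps.to(torch.float16).unsqueeze(1)
    w_hat[di] = dq.to(torch.float16) * ds.to(torch.float16).unsqueeze(1)

    x = torch.randn(2, in_f)
    ref = (x.to(torch.float16) @ w_hat.t()).to(x.dtype)
    assert torch.allclose(layer(x), ref, atol=1e-2)

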
def _set(root, name, value):
    # Write `value` at a dotted module path; numeric parts index into
    # containers such as nn.ModuleList, the rest are attribute names.
    parts = name.split(".")
    parent = root
    for p in parts[:-1]:
        parent = parent[int(p)] if p.isdigit() else getattr(parent, p)
    if parts[-1].isdigit():
        parent[int(parts[-1])] = value
    else:
        setattr(parent, parts[-1], value)


def get_module(root, name):
    # Read the module at a dotted path, mirroring _set's path convention.
    cur = root
    for p in name.split("."):
        cur = cur[int(p)] if p.isdigit() else getattr(cur, p)
    return cur
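
# Example (illustrative): get_module(model, "model.layers.0.mlp.down_proj")
# resolves attribute names and integer indices alike, so "layers.0" steps
# into an nn.ModuleList; _set follows the same convention when writing back.
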
def load_artifact(artifact_dir):
    """Rebuild a quantized model from an artifact directory containing the
    base checkpoint, spec.json, and quantized_modules.pt."""
    d = Path(artifact_dir)
    spec = json.loads((d / "spec.json").read_text("utf-8"))
    state = torch.load(d / "quantized_modules.pt", map_location="cpu")

    # NANO_LOAD_4BIT=1 selects NF4 4-bit loading for the base model;
    # the default is 8-bit.
    use_4bit = os.getenv("NANO_LOAD_4BIT", "0").strip().lower() in {"1", "true", "yes", "on"}
    qcfg = (
        BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        if use_4bit
        else BitsAndBytesConfig(load_in_8bit=True)
    )

    model = AutoModelForCausalLM.from_pretrained(
        str(d),
        quantization_config=qcfg,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(str(d), use_fast=True)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Swap each recorded module for a TrueQuantLinear rebuilt on whatever
    # device device_map="auto" placed the original module on.
    for name, s in state.items():
        dev = next(get_module(model, name).parameters()).device
        bits = s["bits"]
        if "deg_q_packed" in s:
            # Degraded rows are bit-packed: unpack four (2-bit) or two (4-bit)
            # codes per byte, drop padding, then re-center to signed values.
            pk, pad = s["deg_q_packed"], s["pad"]
            if bits == 2:
                dq = torch.stack([pk & 3, (pk >> 2) & 3, (pk >> 4) & 3, (pk >> 6) & 3], dim=-1).view(pk.shape[0], -1)
                if pad > 0:
                    dq = dq[:, :-pad]
                dq = dq.to(torch.int8) - 1
            else:
                dq = torch.stack([pk & 15, (pk >> 4) & 15], dim=-1).view(pk.shape[0], -1)
                if pad > 0:
                    dq = dq[:, :-pad]
                dq = dq.to(torch.int8) - 7
        else:
            dq = s.get("deg_q", torch.zeros(0, dtype=torch.int8))
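
        # Worked example of the packed layout (illustrative): for bits == 2,
        # four 2-bit codes share each byte, lowest bits first, so the byte
        # 0b11_10_01_00 (0xE4) unpacks to codes [0, 1, 2, 3], which the -1
        # offset maps to values [-1, 0, 1, 2]; for 4-bit, codes 0..15 map to
        # -7..8 after the -7 offset.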
        _set(
            model,
            name,
            TrueQuantLinear(
                s["prot_q"],
                s["prot_scale"],
                s["prot_idx"],
                dq,
                s["deg_scale"],
                s["deg_idx"],
                s["out_features"],
                s.get("bias"),
                bits,
                device=str(dev),
            ),
        )
    return model.eval(), tokenizer, spec
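

# --- Hedged usage sketch (not in the original file) --------------------------
# The artifact path, prompt, and generation settings below are placeholders.
if __name__ == "__main__":
    model, tokenizer, spec = load_artifact("path/to/artifact")  # hypothetical path
    prompt = tokenizer("Hello, world", return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(**prompt, max_new_tokens=32)
    print(tokenizer.decode(out[0], skip_special_tokens=True))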