初始化项目,由ModelHub XC社区提供模型
Model: RthItalia/NanoLLM-Qwen2.5-7B-v3.1 Source: Original Platform
This commit is contained in:
119
load_artifact.py
Normal file
119
load_artifact.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Loader NANO-v3.1 UNIVERSAL (Inference Only)"""
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
|
||||
|
||||
|
||||
class TrueQuantLinear(nn.Module):
|
||||
def __init__(self, pq, ps, pi, dq, ds, di, out_features, bias=None, bits=8, device="cuda:0"):
|
||||
super().__init__()
|
||||
self.out_features = out_features
|
||||
self.bits = int(bits)
|
||||
self.register_buffer("pq", pq.to(device=device, dtype=torch.int8))
|
||||
self.register_buffer("ps", ps.to(device=device, dtype=torch.float16))
|
||||
self.register_buffer("pi", pi.to(device=device, dtype=torch.long))
|
||||
self.register_buffer("dq", dq.to(device=device, dtype=torch.int8))
|
||||
self.register_buffer("ds", ds.to(device=device, dtype=torch.float16))
|
||||
self.register_buffer("di", di.to(device=device, dtype=torch.long))
|
||||
if bias is not None:
|
||||
self.register_buffer("bias", bias.to(device=device, dtype=torch.float16))
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
def forward(self, x):
|
||||
d, dt = x.device, x.dtype
|
||||
f = x.to(torch.float16).reshape(-1, x.shape[-1])
|
||||
o = torch.zeros(f.shape[0], self.out_features, dtype=torch.float16, device=d)
|
||||
if self.pq.shape[0] > 0:
|
||||
o.index_copy_(-1, self.pi.to(d), f @ (self.pq.to(d, torch.float16) * self.ps.to(d).unsqueeze(1)).t())
|
||||
if self.dq.shape[0] > 0:
|
||||
o.index_copy_(-1, self.di.to(d), f @ (self.dq.to(d, torch.float16) * self.ds.to(d).unsqueeze(1)).t())
|
||||
if self.bias is not None:
|
||||
o = o + self.bias.to(d)
|
||||
return o.reshape(*x.shape[:-1], self.out_features).to(dt)
|
||||
|
||||
|
||||
def _set(root, name, value):
|
||||
parts = name.split(".")
|
||||
parent = root
|
||||
for p in parts[:-1]:
|
||||
parent = parent[int(p)] if p.isdigit() else getattr(parent, p)
|
||||
if parts[-1].isdigit():
|
||||
parent[int(parts[-1])] = value
|
||||
else:
|
||||
setattr(parent, parts[-1], value)
|
||||
|
||||
|
||||
def get_module(root, name):
|
||||
cur = root
|
||||
for p in name.split("."):
|
||||
cur = cur[int(p)] if p.isdigit() else getattr(cur, p)
|
||||
return cur
|
||||
|
||||
|
||||
def load_artifact(artifact_dir):
|
||||
d = Path(artifact_dir)
|
||||
spec = json.loads((d / "spec.json").read_text("utf-8"))
|
||||
state = torch.load(d / "quantized_modules.pt", map_location="cpu")
|
||||
|
||||
use_4bit = os.getenv("NANO_LOAD_4BIT", "0").strip().lower() in {"1", "true", "yes", "on"}
|
||||
qcfg = (
|
||||
BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
)
|
||||
if use_4bit
|
||||
else BitsAndBytesConfig(load_in_8bit=True)
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
str(d),
|
||||
quantization_config=qcfg,
|
||||
device_map="auto",
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(str(d), use_fast=True)
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
for name, s in state.items():
|
||||
dev = next(get_module(model, name).parameters()).device
|
||||
bits = s["bits"]
|
||||
if "deg_q_packed" in s:
|
||||
pk, pad = s["deg_q_packed"], s["pad"]
|
||||
if bits == 2:
|
||||
dq = torch.stack([pk & 3, (pk >> 2) & 3, (pk >> 4) & 3, (pk >> 6) & 3], dim=-1).view(pk.shape[0], -1)
|
||||
if pad > 0:
|
||||
dq = dq[:, :-pad]
|
||||
dq = dq.to(torch.int8) - 1
|
||||
else:
|
||||
dq = torch.stack([pk & 15, (pk >> 4) & 15], dim=-1).view(pk.shape[0], -1)
|
||||
if pad > 0:
|
||||
dq = dq[:, :-pad]
|
||||
dq = dq.to(torch.int8) - 7
|
||||
else:
|
||||
dq = s.get("deg_q", torch.zeros(0, dtype=torch.int8))
|
||||
|
||||
_set(
|
||||
model,
|
||||
name,
|
||||
TrueQuantLinear(
|
||||
s["prot_q"],
|
||||
s["prot_scale"],
|
||||
s["prot_idx"],
|
||||
dq,
|
||||
s["deg_scale"],
|
||||
s["deg_idx"],
|
||||
s["out_features"],
|
||||
s.get("bias"),
|
||||
bits,
|
||||
device=str(dev),
|
||||
),
|
||||
)
|
||||
return model.eval(), tokenizer, spec
|
||||
|
||||
|
||||
Reference in New Issue
Block a user