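# Custom modeling code for a Nano-quantized Qwen2 causal LM. It defines int8
# per-row linear layers with fp16 scales, a split "protected"/"degraded" row
# variant, an int8 embedding, and an LM head built from the quantized embedding
# table. Which submodules get swapped in is driven by `config.nanollm_modules`.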
import torch
import torch.nn as nn

from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM

class NanoInt8Linear(nn.Module):
    def __init__(self, in_features, out_features, has_bias=False):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.has_bias = bool(has_bias)
        self.register_buffer("q", torch.empty((self.out_features, self.in_features), dtype=torch.int8))
        self.register_buffer("scale", torch.empty((self.out_features,), dtype=torch.float16))
        if self.has_bias:
            self.register_buffer("bias", torch.empty((self.out_features,), dtype=torch.float16))

    def forward(self, x):
        dt = x.dtype
        f = x.to(torch.float16).reshape(-1, x.shape[-1])
        w = self.q.to(f.device, torch.float16) * self.scale.to(f.device).unsqueeze(1)
        y = f @ w.t()
        if self.has_bias:
            y = y + self.bias.to(f.device)
        return y.reshape(*x.shape[:-1], self.out_features).to(dt)
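# Illustrative check of NanoInt8Linear (values made up; the fp16 matmul is
# meant for GPU, or a PyTorch build with half-precision matmul support on CPU):
#
#   lin = NanoInt8Linear(4, 2)
#   lin.q.copy_(torch.randint(-127, 128, (2, 4), dtype=torch.int8))
#   lin.scale.copy_(torch.rand(2, dtype=torch.float16))
#   y = lin(torch.randn(3, 4))  # shape (3, 2), cast back to the input dtype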

class NanoTrueQuantLinear(nn.Module):
    def __init__(self, in_features, out_features, prot_rows, deg_rows, has_bias=False):
        super().__init__()
        self.in_features = int(in_features)
        self.out_features = int(out_features)
        self.has_bias = bool(has_bias)
        self.register_buffer("prot_q", torch.empty((prot_rows, self.in_features), dtype=torch.int8))
        self.register_buffer("prot_scale", torch.empty((prot_rows,), dtype=torch.float16))
        self.register_buffer("prot_idx", torch.empty((prot_rows,), dtype=torch.long))
        self.register_buffer("deg_q", torch.empty((deg_rows, self.in_features), dtype=torch.int8))
        self.register_buffer("deg_scale", torch.empty((deg_rows,), dtype=torch.float16))
        self.register_buffer("deg_idx", torch.empty((deg_rows,), dtype=torch.long))
        if self.has_bias:
            self.register_buffer("bias", torch.empty((self.out_features,), dtype=torch.float16))

    def forward(self, x):
        dt = x.dtype
        f = x.to(torch.float16).reshape(-1, x.shape[-1])
        y = torch.zeros((f.shape[0], self.out_features), dtype=torch.float16, device=f.device)
        if self.prot_q.shape[0] > 0:
            w = self.prot_q.to(f.device, torch.float16) * self.prot_scale.to(f.device).unsqueeze(1)
            y.index_copy_(-1, self.prot_idx.to(f.device), f @ w.t())
        if self.deg_q.shape[0] > 0:
            w = self.deg_q.to(f.device, torch.float16) * self.deg_scale.to(f.device).unsqueeze(1)
            y.index_copy_(-1, self.deg_idx.to(f.device), f @ w.t())
        if self.has_bias:
            y = y + self.bias.to(f.device)
        return y.reshape(*x.shape[:-1], self.out_features).to(dt)
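# NanoTrueQuantLinear holds two int8 row groups ("protected" and "degraded"),
# each with per-row fp16 scales, and scatters their partial outputs back into
# the full out_features dimension via prot_idx / deg_idx; output rows covered
# by neither index stay zero. Illustrative wiring (sizes made up):
#
#   lin = NanoTrueQuantLinear(in_features=8, out_features=4, prot_rows=1, deg_rows=3)
#   lin.prot_idx.copy_(torch.tensor([0]))
#   lin.deg_idx.copy_(torch.tensor([1, 2, 3]))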

class NanoEmbedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.num_embeddings = int(num_embeddings)
        self.embedding_dim = int(embedding_dim)
        self.register_buffer("q", torch.empty((self.num_embeddings, self.embedding_dim), dtype=torch.int8))
        self.register_buffer("scale", torch.empty((self.num_embeddings,), dtype=torch.float16))

    def forward(self, input_ids):
        return self.q[input_ids].to(torch.float16) * self.scale[input_ids].to(torch.float16).unsqueeze(-1)
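# NanoEmbedding dequantizes only the looked-up rows (int8 table times per-row
# fp16 scale); the returned embeddings are always fp16. Illustrative lookup
# (sizes made up; the buffers would normally come from a loaded state dict):
#
#   emb = NanoEmbedding(num_embeddings=16, embedding_dim=8)
#   vecs = emb(torch.tensor([[1, 2, 3]]))  # shape (1, 3, 8), dtype float16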

class NanoTiedLMHead(nn.Module):
    def __init__(self, embedding):
        super().__init__()
        self.register_buffer("q", embedding.q.detach().clone())
        self.register_buffer("scale", embedding.scale.detach().clone())

    def forward(self, x):
        w = self.q.to(x.device, torch.float16) * self.scale.to(x.device).unsqueeze(1)
        return x.to(torch.float16) @ w.t()
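# NanoTiedLMHead copies the quantized embedding table (and its scales) at
# construction time, so the logits mirror a tied embedding/lm_head without
# relying on transformers' weight tying (which is disabled below). Unlike the
# linear layers above, it returns fp16 logits rather than casting back to the
# input dtype.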

def _set_module(root, name, module):
    cur = root
    parts = name.split(".")
    for p in parts[:-1]:
        cur = cur[int(p)] if p.isdigit() else getattr(cur, p)
    setattr(cur, parts[-1], module)
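# _set_module resolves a dotted submodule path (numeric parts index into module
# lists such as model.layers) and assigns the replacement in place, e.g.
# (illustrative path and sizes):
#
#   _set_module(model, "model.layers.0.mlp.gate_proj", NanoInt8Linear(in_f, out_f))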

class NanoQwenForCausalLM(Qwen2ForCausalLM):
    config_class = Qwen2Config

    def tie_weights(self, *args, **kwargs):
        return None

    def mark_tied_weights_as_initialized(self, *args, **kwargs):
        return None

    def __init__(self, config):
        config.tie_word_embeddings = False
        super().__init__(config)
        self.config.tie_word_embeddings = False
        self._tied_weights_keys = []
        self.all_tied_weights_keys = {}
        mods = getattr(config, "nanollm_modules", {})
        for name, spec in mods.items():
            kind = spec["kind"]
            if kind == "embedding":
                mod = NanoEmbedding(spec["num_embeddings"], spec["embedding_dim"])
            elif kind == "int8_linear":
                mod = NanoInt8Linear(spec["in_features"], spec["out_features"], spec.get("has_bias", False))
            elif kind == "truequant_linear":
                mod = NanoTrueQuantLinear(
                    spec["in_features"], spec["out_features"],
                    spec["prot_rows"], spec["deg_rows"],
                    spec.get("has_bias", False),
                )
            else:
                raise ValueError(f"Unknown Nano module kind: {kind}")
            _set_module(self, name, mod)
        if "lm_head" not in mods and isinstance(self.model.embed_tokens, NanoEmbedding):
            self.lm_head = NanoTiedLMHead(self.model.embed_tokens)
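# The swap map is read from `config.nanollm_modules`: each entry names a
# submodule path and a spec dict with the keys consumed above. Illustrative
# shape of that config (paths and sizes made up; the int8/scale buffers are
# presumably filled when the quantized checkpoint's state dict is loaded):
#
#   config.nanollm_modules = {
#       "model.embed_tokens": {"kind": "embedding", "num_embeddings": 151936, "embedding_dim": 896},
#       "model.layers.0.self_attn.q_proj": {
#           "kind": "truequant_linear", "in_features": 896, "out_features": 896,
#           "prot_rows": 128, "deg_rows": 768, "has_bias": True,
#       },
#       "model.layers.0.mlp.gate_proj": {"kind": "int8_linear", "in_features": 896, "out_features": 4864},
#   }
#   model = NanoQwenForCausalLM(config)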