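"""Decoder-only transformer definition for the OLMo staging converter (MLX).

RoPE attention with optional QKV activation clipping, SwiGLU MLPs,
non-parametric LayerNorm, and optionally tied input/output embeddings.
"""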
from dataclasses import dataclass
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn

from mlx_lm.models.activations import swiglu
from mlx_lm.models.base import BaseModelArgs, create_attention_mask


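# Model hyperparameters; field names mirror the keys in the checkpoint's config.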
@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    vocab_size: int
    num_key_value_heads: int = 0
    rope_theta: float = 10000.0
    tie_word_embeddings: bool = False
    attention_bias: bool = False
    clip_qkv: Optional[float] = None

    def __post_init__(self):
        if not self.num_key_value_heads:
            self.num_key_value_heads = self.num_attention_heads
        if self.num_key_value_heads != self.num_attention_heads:
            raise ValueError(
                "Grouped-query attention is not yet implemented for this "
                "OLMo staging converter."
            )


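# Multi-head self-attention with rotary position embeddings and optional activation clipping.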
class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.head_dim = dim // self.n_heads
        self.scale = self.head_dim**-0.5
        self.clip_qkv = args.clip_qkv

        self.q_proj = nn.Linear(dim, dim, bias=args.attention_bias)
        self.k_proj = nn.Linear(dim, dim, bias=args.attention_bias)
        self.v_proj = nn.Linear(dim, dim, bias=args.attention_bias)
        self.o_proj = nn.Linear(dim, dim, bias=args.attention_bias)
        self.rope = nn.RoPE(self.head_dim, traditional=False, base=args.rope_theta)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        bsz, seq_len, _ = x.shape
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Optionally clamp the projected activations to [-clip_qkv, clip_qkv].
        if self.clip_qkv is not None:
            q = mx.clip(q, -self.clip_qkv, self.clip_qkv)
            k = mx.clip(k, -self.clip_qkv, self.clip_qkv)
            v = mx.clip(v, -self.clip_qkv, self.clip_qkv)

        # Split heads: (batch, seq, dim) -> (batch, heads, seq, head_dim).
        q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = k.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = v.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)

        # Apply rotary embeddings at the cache offset, then extend the KV cache.
        if cache is not None:
            q = self.rope(q, offset=cache.offset)
            k = self.rope(k, offset=cache.offset)
            k, v = cache.update_and_fetch(k, v)
        else:
            q = self.rope(q)
            k = self.rope(k)

        out = mx.fast.scaled_dot_product_attention(q, k, v, scale=self.scale, mask=mask)
        out = out.transpose(0, 2, 1, 3).reshape(bsz, seq_len, -1)
        return self.o_proj(out)


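# Gated (SwiGLU) feed-forward network.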
class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        hidden = args.intermediate_size
        self.gate_proj = nn.Linear(dim, hidden, bias=False)
        self.up_proj = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        return self.down_proj(swiglu(self.gate_proj(x), self.up_proj(x)))


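# Pre-norm transformer block: residual attention and MLP sublayers.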
class DecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        self.input_layernorm = nn.LayerNorm(dim, affine=False)
        self.post_attention_layernorm = nn.LayerNorm(dim, affine=False)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        h = x + self.self_attn(self.input_layernorm(x), mask, cache)
        return h + self.mlp(self.post_attention_layernorm(h))


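# Token embedding, the stack of decoder layers, and the final norm.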
class InnerModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [DecoderLayer(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.LayerNorm(args.hidden_size, affine=False)

    def __call__(self, inputs: mx.array, cache=None):
        h = self.embed_tokens(inputs)
        if cache is None:
            cache = [None] * len(self.layers)
        # Build a causal attention mask, offset by any cached prefix.
        mask = create_attention_mask(h, cache[0])
        for layer, layer_cache in zip(self.layers, cache):
            h = layer(h, mask, layer_cache)
        h = self.norm(h)
        return h, cache


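# Top-level wrapper that adds the output projection (or reuses the tied embedding matrix).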
class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = InnerModel(args)
        self.args = args
        self.tie_word_embeddings = args.tie_word_embeddings
        if not self.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        h, cache = self.model(inputs, cache)
        if self.tie_word_embeddings:
            return self.model.embed_tokens.as_linear(h), cache
        return self.lm_head(h), cache

    @property
    def layers(self):
        return self.model.layers