from dataclasses import dataclass
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn

from mlx_lm.models.base import BaseModelArgs, create_attention_mask


@dataclass
class ModelArgs(BaseModelArgs):
    model_type: str
    hidden_size: int
    num_hidden_layers: int
    intermediate_size: int
    num_attention_heads: int
    vocab_size: int
    num_key_value_heads: int = 0
    rope_theta: float = 10000.0
    tie_word_embeddings: bool = False
    attention_bias: bool = False
    clip_qkv: Optional[float] = None

    def __post_init__(self):
        if not self.num_key_value_heads:
            self.num_key_value_heads = self.num_attention_heads
        if self.num_key_value_heads != self.num_attention_heads:
            raise ValueError(
                "Grouped-query attention is not yet implemented for this "
                "OLMo staging converter."
            )


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        self.n_heads = args.num_attention_heads
        self.head_dim = dim // self.n_heads
        self.scale = self.head_dim**-0.5
        self.clip_qkv = args.clip_qkv

        self.q_proj = nn.Linear(dim, dim, bias=args.attention_bias)
        self.k_proj = nn.Linear(dim, dim, bias=args.attention_bias)
        self.v_proj = nn.Linear(dim, dim, bias=args.attention_bias)
        self.o_proj = nn.Linear(dim, dim, bias=args.attention_bias)
        self.rope = nn.RoPE(self.head_dim, traditional=False, base=args.rope_theta)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        bsz, seq_len, _ = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # OLMo optionally clips the QKV activations for stability; apply the
        # same clipping at inference for checkpoint parity.
        if self.clip_qkv is not None:
            q = mx.clip(q, -self.clip_qkv, self.clip_qkv)
            k = mx.clip(k, -self.clip_qkv, self.clip_qkv)
            v = mx.clip(v, -self.clip_qkv, self.clip_qkv)

        # (B, L, D) -> (B, n_heads, L, head_dim)
        q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = k.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        v = v.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)

        if cache is not None:
            q = self.rope(q, offset=cache.offset)
            k = self.rope(k, offset=cache.offset)
            k, v = cache.update_and_fetch(k, v)
        else:
            q = self.rope(q)
            k = self.rope(k)

        out = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=self.scale, mask=mask
        )
        out = out.transpose(0, 2, 1, 3).reshape(bsz, seq_len, -1)
        return self.o_proj(out)


class MLP(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        hidden = args.intermediate_size
        self.gate_proj = nn.Linear(dim, hidden, bias=False)
        self.up_proj = nn.Linear(dim, hidden, bias=False)
        self.down_proj = nn.Linear(hidden, dim, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        # SwiGLU feed-forward: silu(gate) * up, projected back to model dim.
        return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))


class DecoderLayer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        dim = args.hidden_size
        self.self_attn = Attention(args)
        self.mlp = MLP(args)
        # OLMo uses non-parametric LayerNorm (no learned scale or bias).
        self.input_layernorm = nn.LayerNorm(dim, affine=False)
        self.post_attention_layernorm = nn.LayerNorm(dim, affine=False)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        h = x + self.self_attn(self.input_layernorm(x), mask, cache)
        return h + self.mlp(self.post_attention_layernorm(h))


class InnerModel(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = [DecoderLayer(args) for _ in range(args.num_hidden_layers)]
        self.norm = nn.LayerNorm(args.hidden_size, affine=False)

    def __call__(self, inputs: mx.array, cache=None):
        h = self.embed_tokens(inputs)

        if cache is None:
            cache = [None] * len(self.layers)

        # create_attention_mask expects the per-layer cache list, not a
        # single layer's cache.
        mask = create_attention_mask(h, cache)

        for layer, layer_cache in zip(self.layers, cache):
            h = layer(h, mask, layer_cache)

        h = self.norm(h)
        return h, cache


class Model(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.model_type = args.model_type
        self.model = InnerModel(args)
        self.args = args
        self.tie_word_embeddings = args.tie_word_embeddings
        if not self.tie_word_embeddings:
            self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False)

    def __call__(self, inputs: mx.array, cache=None):
        h, cache = self.model(inputs, cache)
        if self.tie_word_embeddings:
            # Reuse the embedding matrix as the output projection.
            return self.model.embed_tokens.as_linear(h), cache
        return self.lm_head(h), cache

    @property
    def layers(self):
        return self.model.layers
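
# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch for local sanity checking, not part of the
# converter. The config values below are made up for a tiny model and do not
# correspond to any released OLMo checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = ModelArgs(
        model_type="olmo",
        hidden_size=64,
        num_hidden_layers=2,
        intermediate_size=256,
        num_attention_heads=4,
        vocab_size=128,
    )
    model = Model(args)

    # Dummy token ids with shape (batch=1, seq_len=5).
    tokens = mx.array([[3, 1, 4, 1, 5]])
    logits, _ = model(tokens)
    assert logits.shape == (1, 5, args.vocab_size)
    print("forward pass OK:", logits.shape)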