初始化项目,由ModelHub XC社区提供模型

Model: divakar-yadav/transformer-1b-chat
Source: Original Platform
This commit is contained in:
ModelHub XC
2026-06-20 17:27:58 +08:00
commit 070e055bf5
25 changed files with 273003 additions and 0 deletions

318
training_code/chat.py Normal file
View File

@@ -0,0 +1,318 @@
#!/usr/bin/env python3
"""
Interactive chat with the 1B Transformer.
Runs in an infinite conversation loop from the terminal.
Usage:
python chat.py # auto-find latest checkpoint
python chat.py /jfs/deepak-kumar/checkpoints/step_19000.pt # specific checkpoint
"""
import sys
import os
import glob
import time
import torch
import torch.nn.functional as F
import readline # enables arrow keys and history in input()
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer
def find_latest_checkpoint():
"""Look for DPO > SFT > pretrained checkpoint."""
dpo_dir = "/jfs/deepak-kumar/checkpoints_dpo"
sft_dir = "/jfs/deepak-kumar/checkpoints_sft"
pt_dir = "/jfs/deepak-kumar/checkpoints"
# Prefer DPO final
dpo_final = os.path.join(dpo_dir, "dpo_final.pt")
if os.path.exists(dpo_final):
return dpo_final, True
dpo_files = glob.glob(os.path.join(dpo_dir, "dpo_step_*.pt"))
if dpo_files:
return max(dpo_files, key=lambda f: int(f.split("dpo_step_")[1].split(".")[0])), True
# Then SFT
sft_final = os.path.join(sft_dir, "sft_final.pt")
if os.path.exists(sft_final):
return sft_final, True
sft_files = glob.glob(os.path.join(sft_dir, "sft_step_*.pt"))
if sft_files:
return max(sft_files, key=lambda f: int(f.split("sft_step_")[1].split(".")[0])), True
# Fall back to pretrained
pt_files = glob.glob(os.path.join(pt_dir, "step_*.pt"))
if pt_files:
return max(pt_files, key=lambda f: int(os.path.basename(f).split("_")[1].split(".")[0])), False
return None, False
def load_model(checkpoint_path, tokenizer, device="cuda:0"):
config = ModelConfig()
model = Transformer(config)
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
# Handle expanded vocab from SFT
saved_vocab = ckpt.get("vocab_size", config.vocab_size)
if saved_vocab > config.vocab_size:
config.vocab_size = saved_vocab
model = Transformer(config)
model.load_state_dict(ckpt["model"])
model = model.to(device).bfloat16().eval()
step = ckpt.get("step", "?")
loss = ckpt.get("loss", "?")
del ckpt
torch.cuda.empty_cache()
return model, config, step, loss
@torch.no_grad()
def generate_stream(model, tokenizer, prompt, max_new_tokens=512,
temperature=0.8, top_k=50, top_p=0.9,
repetition_penalty=1.15, device="cuda:0",
stop_token_ids=None):
"""Generate tokens one at a time, yielding each for streaming output."""
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
generated_ids = []
prev_decoded_len = 0
if stop_token_ids is None:
stop_token_ids = set()
else:
stop_token_ids = set(stop_token_ids)
stop_token_ids.add(tokenizer.eos_token_id)
for _ in range(max_new_tokens):
if input_ids.shape[1] >= model.config.max_seq_len:
break
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
logits, _ = model(input_ids)
logits = logits[:, -1, :]
if repetition_penalty != 1.0 and generated_ids:
prev_tokens = torch.tensor(generated_ids, device=device).unique()
for token_id in prev_tokens:
if logits[0, token_id] > 0:
logits[0, token_id] /= repetition_penalty
else:
logits[0, token_id] *= repetition_penalty
logits = logits / temperature
if top_k > 0:
topk_vals, _ = torch.topk(logits, top_k)
logits[logits < topk_vals[:, -1:]] = float("-inf")
if top_p < 1.0:
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
mask = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p
sorted_logits[mask] = float("-inf")
logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
token_id = next_token.item()
# Stop on any stop token (EOS, <|end|>, <|user|>)
if token_id in stop_token_ids:
break
generated_ids.append(token_id)
input_ids = torch.cat([input_ids, next_token], dim=1)
full_decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
new_text = full_decoded[prev_decoded_len:]
prev_decoded_len = len(full_decoded)
yield new_text
return
def print_banner(step, loss, device):
print("\033[1;36m") # cyan bold
print("=" * 60)
print(" 1B TRANSFORMER — Interactive Chat")
print("=" * 60)
print(f"\033[0m Checkpoint : step {step}")
print(f" Loss : {loss}")
print(f" Device : {device}")
print(f" Parameters : 1.106B")
print()
print(" \033[90mCommands:\033[0m")
print(" \033[33m/quit\033[0m — exit")
print(" \033[33m/clear\033[0m — clear conversation context")
print(" \033[33m/temp N\033[0m — set temperature (default 0.8)")
print(" \033[33m/tokens N\033[0m — set max tokens (default 512)")
print(" \033[33m/topp N\033[0m — set top-p (default 0.9)")
print(" \033[33m/topk N\033[0m — set top-k (default 50)")
print(" \033[33m/rep N\033[0m — set repetition penalty (default 1.15)")
print()
print("\033[90m" + "" * 60 + "\033[0m")
def main():
device = "cuda:0"
is_sft = False
if len(sys.argv) > 1:
checkpoint = sys.argv[1]
is_sft = "sft" in checkpoint.lower()
else:
result = find_latest_checkpoint()
if result[0] is None:
print("No checkpoint found!")
sys.exit(1)
checkpoint, is_sft = result
tokenizer = get_tokenizer()
# Add chat tokens for SFT models
if is_sft:
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
vocab = tokenizer.get_vocab()
new_tokens = [t for t in special_tokens if t not in vocab]
if new_tokens:
tokenizer.add_tokens(new_tokens, special_tokens=True)
print(f"\n Loading model from {checkpoint}...")
print(f" Mode: {'SFT (chat)' if is_sft else 'Base (completion)'}")
model, config, step, loss = load_model(checkpoint, tokenizer, device)
print(f" Model loaded!\n")
print_banner(step, loss, device)
if is_sft:
print(" \033[1;32mSFT mode: The model will respond as a chat assistant.\033[0m\n")
# Settings
temperature = 0.7 if is_sft else 0.8
max_tokens = 512
top_p = 0.9
top_k = 50
rep_penalty = 1.15
context = ""
# Chat template tokens for SFT
USER_START = "<|user|>\n"
ASST_START = "<|assistant|>\n"
TURN_END = "\n<|end|>\n"
# Build stop token IDs for generation
sft_stop_ids = []
if is_sft:
vocab = tokenizer.get_vocab()
for tok_str in ["<|end|>", "<|user|>"]:
if tok_str in vocab:
sft_stop_ids.append(vocab[tok_str])
while True:
try:
user_input = input("\n\033[1;32mYou:\033[0m ").strip()
except (KeyboardInterrupt, EOFError):
print("\n\nGoodbye!")
break
if not user_input:
continue
# Handle commands
if user_input.startswith("/"):
cmd = user_input.lower().split()
if cmd[0] == "/quit":
print("Goodbye!")
break
elif cmd[0] == "/clear":
context = ""
print("\033[90m [Context cleared]\033[0m")
continue
elif cmd[0] == "/temp" and len(cmd) > 1:
temperature = float(cmd[1])
print(f"\033[90m [Temperature set to {temperature}]\033[0m")
continue
elif cmd[0] == "/tokens" and len(cmd) > 1:
max_tokens = int(cmd[1])
print(f"\033[90m [Max tokens set to {max_tokens}]\033[0m")
continue
elif cmd[0] == "/topp" and len(cmd) > 1:
top_p = float(cmd[1])
print(f"\033[90m [Top-p set to {top_p}]\033[0m")
continue
elif cmd[0] == "/topk" and len(cmd) > 1:
top_k = int(cmd[1])
print(f"\033[90m [Top-k set to {top_k}]\033[0m")
continue
elif cmd[0] == "/rep" and len(cmd) > 1:
rep_penalty = float(cmd[1])
print(f"\033[90m [Repetition penalty set to {rep_penalty}]\033[0m")
continue
else:
print("\033[90m Unknown command. Try /quit, /clear, /temp, /tokens, /topp, /topk, /rep\033[0m")
continue
# Build prompt
if is_sft:
prompt = context + USER_START + user_input + TURN_END + ASST_START
else:
if context:
prompt = context + "\n" + user_input
else:
prompt = user_input
# Trim context if too long
while len(tokenizer.encode(prompt)) > config.max_seq_len - max_tokens:
if is_sft:
parts = context.split(TURN_END)
if len(parts) <= 2:
break
context = TURN_END.join(parts[2:])
prompt = context + USER_START + user_input + TURN_END + ASST_START
else:
lines = prompt.split("\n")
if len(lines) <= 2:
break
prompt = "\n".join(lines[1:])
# Generate with streaming
print("\033[1;34mModel:\033[0m ", end="", flush=True)
t0 = time.time()
full_response = ""
token_count = 0
for token_text in generate_stream(
model, tokenizer, prompt,
max_new_tokens=max_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
repetition_penalty=rep_penalty,
device=device,
stop_token_ids=sft_stop_ids if is_sft else None,
):
print(token_text, end="", flush=True)
full_response += token_text
token_count += 1
elapsed = time.time() - t0
tps = token_count / max(elapsed, 1e-9)
print(f"\n\033[90m [{token_count} tokens, {tps:.1f} tok/s, {elapsed:.1f}s]\033[0m")
# Append to context for multi-turn
if is_sft:
context = (context + USER_START + user_input + TURN_END +
ASST_START + full_response.strip() + TURN_END)
else:
context = prompt + full_response
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,167 @@
"""
Export the trained model to HuggingFace-compatible format.
Creates:
- model.safetensors (weights)
- config.json (architecture config)
- generation_config.json
- tokenizer.json, tokenizer_config.json, special_tokens_map.json
"""
import os
import sys
import json
import torch
from collections import OrderedDict
from safetensors.torch import save_file
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer
CHECKPOINT = "/jfs/deepak-kumar/checkpoints_dpo/dpo_final.pt"
OUTPUT_DIR = "/home/jovyan/training/hf_model"
os.makedirs(OUTPUT_DIR, exist_ok=True)
print("=" * 60)
print(" EXPORTING MODEL TO HUGGING FACE FORMAT")
print("=" * 60)
# --- 1. Load model ---
print("\n[1/4] Loading checkpoint...")
tokenizer = get_tokenizer()
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
vocab = tokenizer.get_vocab()
new_tokens = [t for t in special_tokens if t not in vocab]
if new_tokens:
tokenizer.add_tokens(new_tokens, special_tokens=True)
model_config = ModelConfig()
model_config.vocab_size = len(tokenizer)
model = Transformer(model_config)
ckpt = torch.load(CHECKPOINT, map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["model"])
step = ckpt.get("step", 0)
del ckpt
print(f" Loaded DPO model (step {step}, vocab {model_config.vocab_size})")
# --- 2. Convert state dict keys to HF-style naming ---
print("\n[2/4] Converting weights to safetensors...")
state_dict = model.state_dict()
hf_state = OrderedDict()
KEY_MAP = {
"tok_embeddings.weight": "model.embed_tokens.weight",
"norm.weight": "model.norm.weight",
"output.weight": "lm_head.weight",
}
for key, tensor in state_dict.items():
if key in KEY_MAP:
hf_state[KEY_MAP[key]] = tensor
continue
if key.startswith("layers."):
parts = key.split(".")
layer_idx = parts[1]
rest = ".".join(parts[2:])
layer_map = {
"attention_norm.weight": f"model.layers.{layer_idx}.input_layernorm.weight",
"ffn_norm.weight": f"model.layers.{layer_idx}.post_attention_layernorm.weight",
"attention.wq.weight": f"model.layers.{layer_idx}.self_attn.q_proj.weight",
"attention.wk.weight": f"model.layers.{layer_idx}.self_attn.k_proj.weight",
"attention.wv.weight": f"model.layers.{layer_idx}.self_attn.v_proj.weight",
"attention.wo.weight": f"model.layers.{layer_idx}.self_attn.o_proj.weight",
"ffn.w_gate.weight": f"model.layers.{layer_idx}.mlp.gate_proj.weight",
"ffn.w_up.weight": f"model.layers.{layer_idx}.mlp.up_proj.weight",
"ffn.w_down.weight": f"model.layers.{layer_idx}.mlp.down_proj.weight",
}
if rest in layer_map:
hf_state[layer_map[rest]] = tensor
else:
print(f" WARNING: unmapped key {key}")
hf_state[key] = tensor
elif key == "freqs_cis":
continue
else:
print(f" WARNING: unmapped key {key}")
hf_state[key] = tensor
# Convert all to bfloat16 for storage
for k in hf_state:
if hf_state[k].dtype == torch.float32:
hf_state[k] = hf_state[k].to(torch.bfloat16)
safetensors_path = os.path.join(OUTPUT_DIR, "model.safetensors")
save_file(hf_state, safetensors_path)
size_gb = os.path.getsize(safetensors_path) / 1e9
print(f" Saved {len(hf_state)} tensors -> {safetensors_path} ({size_gb:.2f} GB)")
# --- 3. Write config files ---
print("\n[3/4] Writing config files...")
config_json = {
"architectures": ["LlamaForCausalLM"],
"model_type": "llama",
"vocab_size": model_config.vocab_size,
"hidden_size": model_config.hidden_dim,
"intermediate_size": model_config.intermediate_dim,
"num_hidden_layers": model_config.num_layers,
"num_attention_heads": model_config.num_attention_heads,
"num_key_value_heads": model_config.num_kv_heads,
"max_position_embeddings": model_config.max_seq_len,
"rope_theta": model_config.rope_theta,
"rms_norm_eps": model_config.rms_norm_eps,
"hidden_act": "silu",
"initializer_range": 0.02,
"tie_word_embeddings": False,
"torch_dtype": "bfloat16",
"transformers_version": "4.40.0",
"use_cache": True,
"bos_token_id": tokenizer.bos_token_id,
"eos_token_id": tokenizer.eos_token_id,
"pad_token_id": tokenizer.pad_token_id,
}
with open(os.path.join(OUTPUT_DIR, "config.json"), "w") as f:
json.dump(config_json, f, indent=2)
print(" config.json")
gen_config = {
"bos_token_id": tokenizer.bos_token_id,
"eos_token_id": tokenizer.eos_token_id,
"pad_token_id": tokenizer.pad_token_id,
"do_sample": True,
"temperature": 0.7,
"top_k": 50,
"top_p": 0.9,
"repetition_penalty": 1.15,
"max_new_tokens": 512,
"transformers_version": "4.40.0",
}
with open(os.path.join(OUTPUT_DIR, "generation_config.json"), "w") as f:
json.dump(gen_config, f, indent=2)
print(" generation_config.json")
# --- 4. Export tokenizer ---
print("\n[4/4] Exporting tokenizer...")
tokenizer.save_pretrained(OUTPUT_DIR)
print(" Tokenizer files saved")
print("\n" + "=" * 60)
print(" EXPORT COMPLETE -> " + OUTPUT_DIR)
print("=" * 60)
print("\nFiles:")
for f in sorted(os.listdir(OUTPUT_DIR)):
size = os.path.getsize(os.path.join(OUTPUT_DIR, f))
if size > 1e6:
print(f" {f:40s} {size/1e6:.1f} MB")
else:
print(f" {f:40s} {size/1e3:.1f} KB")

134
training_code/inference.py Normal file
View File

@@ -0,0 +1,134 @@
"""
Inference script for the 1B Transformer — Single GPU.
Usage:
python inference.py # auto-finds latest checkpoint
python inference.py /path/to/checkpoint.pt # specific checkpoint
"""
import sys
import os
import glob
import time
import torch
import torch.nn.functional as F
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer
def find_latest_checkpoint(checkpoint_dir="/jfs/deepak-kumar/checkpoints"):
files = glob.glob(os.path.join(checkpoint_dir, "step_*.pt"))
if not files:
final = os.path.join(checkpoint_dir, "final.pt")
return final if os.path.exists(final) else None
return max(files, key=lambda f: int(os.path.basename(f).split("_")[1].split(".")[0]))
def load_model(checkpoint_path, device="cuda:0"):
config = ModelConfig()
model = Transformer(config)
print(f"Loading checkpoint: {checkpoint_path}")
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["model"])
model = model.to(device).bfloat16().eval()
step = ckpt.get("step", "?")
loss = ckpt.get("loss", "?")
print(f" Step: {step} | Loss: {loss}")
print(f" Params: {sum(p.numel() for p in model.parameters()):,}")
print(f" Device: {device}")
del ckpt
torch.cuda.empty_cache()
return model, config
@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=200,
temperature=0.8, top_k=50, top_p=0.9, device="cuda:0"):
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
t0 = time.time()
for i in range(max_new_tokens):
if input_ids.shape[1] >= model.config.max_seq_len:
break
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
logits, _ = model(input_ids)
logits = logits[:, -1, :] / temperature
if top_k > 0:
topk_vals, _ = torch.topk(logits, top_k)
logits[logits < topk_vals[:, -1:]] = float("-inf")
if top_p < 1.0:
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
mask = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p
sorted_logits[mask] = float("-inf")
logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
if next_token.item() == tokenizer.eos_token_id:
break
input_ids = torch.cat([input_ids, next_token], dim=1)
elapsed = time.time() - t0
gen_tokens = input_ids.shape[1] - len(tokenizer.encode(prompt))
tok_per_sec = gen_tokens / max(elapsed, 1e-9)
text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
return text, gen_tokens, tok_per_sec
def main():
device = "cuda:0"
if len(sys.argv) > 1:
checkpoint = sys.argv[1]
else:
checkpoint = find_latest_checkpoint()
if checkpoint is None:
print("No checkpoint found!")
sys.exit(1)
model, config = load_model(checkpoint, device)
tokenizer = get_tokenizer()
prompts = [
"The meaning of life is",
"In machine learning, a neural network",
"The capital of France is",
"Once upon a time, there was a",
"To solve a quadratic equation, you need to",
"The theory of relativity explains that",
"Python is a programming language that",
"The sun rises in the east and",
]
print("\n" + "=" * 70)
print(" INFERENCE — 1B Transformer (Single GPU)")
print("=" * 70)
for prompt in prompts:
print(f"\n{'' * 60}")
print(f"PROMPT: {prompt}")
print(f"{'' * 60}")
text, n_tok, tps = generate(model, tokenizer, prompt,
max_new_tokens=150, temperature=0.8,
top_k=50, device=device)
generated = text[len(prompt):]
print(f"OUTPUT:{generated}")
print(f" [{n_tok} tokens, {tps:.1f} tok/s]")
print("\n" + "=" * 70)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,2 @@
from .config import ModelConfig, TrainConfig
from .transformer import Transformer

View File

@@ -0,0 +1,78 @@
"""
Configuration for 1B parameter LLaMA-style Transformer model.
Architecture: Decoder-only Transformer with RoPE, GQA, SwiGLU, RMSNorm.
"""
from dataclasses import dataclass
@dataclass
class ModelConfig:
vocab_size: int = 32000
hidden_dim: int = 2048
intermediate_dim: int = 5504 # ~2.7x hidden for SwiGLU (adjusted for param count)
num_layers: int = 22
num_attention_heads: int = 32
num_kv_heads: int = 8 # GQA: 4 query heads per KV head
max_seq_len: int = 2048
rope_theta: float = 10000.0
rms_norm_eps: float = 1e-5
dropout: float = 0.0 # No dropout (modern practice for pretraining)
tie_word_embeddings: bool = False
@property
def head_dim(self) -> int:
return self.hidden_dim // self.num_attention_heads
@property
def num_params_approx(self) -> int:
"""Rough parameter count estimate."""
embed = self.vocab_size * self.hidden_dim
attn_per_layer = (
self.hidden_dim * self.head_dim * self.num_attention_heads + # Q
self.hidden_dim * self.head_dim * self.num_kv_heads + # K
self.hidden_dim * self.head_dim * self.num_kv_heads + # V
self.head_dim * self.num_attention_heads * self.hidden_dim # O
)
ffn_per_layer = 3 * self.hidden_dim * self.intermediate_dim # gate + up + down
norm_per_layer = 2 * self.hidden_dim
total = (
embed +
self.num_layers * (attn_per_layer + ffn_per_layer + norm_per_layer) +
self.hidden_dim + # final norm
(0 if self.tie_word_embeddings else self.vocab_size * self.hidden_dim)
)
return total
@dataclass
class TrainConfig:
# Paths
checkpoint_dir: str = "/jfs/deepak-kumar/checkpoints"
data_cache_dir: str = "/jfs/deepak-kumar/data"
log_dir: str = "/home/jovyan/training/logs"
# Training
total_tokens: int = 20_000_000_000 # 20B tokens
batch_size_per_gpu: int = 8
gradient_accumulation_steps: int = 8 # effective batch = 8 * 8 * 8 = 512 seqs
max_seq_len: int = 2048
# WSD Schedule
learning_rate: float = 3e-4
min_lr: float = 3e-5
warmup_steps: int = 1000
weight_decay: float = 0.1
beta1: float = 0.9
beta2: float = 0.95
grad_clip: float = 1.0
# Logging
log_interval: int = 10
save_interval: int = 1000
eval_interval: int = 500
# System
num_workers: int = 4
seed: int = 42
bf16: bool = True

View File

@@ -0,0 +1,79 @@
"""
Data pipeline: streams and tokenizes OpenWebText for pretraining.
Packs sequences to max_seq_len for efficiency (no padding waste).
"""
import os
import torch
from torch.utils.data import IterableDataset, DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer
def get_tokenizer(name: str = "mistralai/Mistral-7B-v0.1"):
"""Use Mistral's tokenizer — 32k vocab, BPE, well-trained on diverse data."""
tok = AutoTokenizer.from_pretrained(name, use_fast=True)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
return tok
class PackedPretrainDataset(IterableDataset):
"""
Streams text from HuggingFace dataset, tokenizes on the fly,
and packs into fixed-length sequences for maximum GPU utilization.
"""
def __init__(self, tokenizer, max_seq_len: int, split: str = "train", cache_dir: str = None, seed: int = 42):
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
self.split = split
self.cache_dir = cache_dir
self.seed = seed
self.eos_id = tokenizer.eos_token_id
def _token_stream(self):
ds = load_dataset(
"HuggingFaceFW/fineweb-edu",
name="sample-10BT",
split=self.split,
streaming=True,
cache_dir=self.cache_dir,
)
ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
for example in ds:
text = example.get("text", "")
if len(text.strip()) < 50:
continue
token_ids = self.tokenizer.encode(text, add_special_tokens=False)
yield from token_ids
yield self.eos_id
def __iter__(self):
buffer = []
for token_id in self._token_stream():
buffer.append(token_id)
if len(buffer) == self.max_seq_len + 1:
input_ids = torch.tensor(buffer[:-1], dtype=torch.long)
labels = torch.tensor(buffer[1:], dtype=torch.long)
yield input_ids, labels
buffer = []
def create_dataloader(tokenizer, config, rank: int = 0, world_size: int = 1, seed_override: int = None):
seed = seed_override if seed_override is not None else config.seed
dataset = PackedPretrainDataset(
tokenizer=tokenizer,
max_seq_len=config.max_seq_len,
split="train",
cache_dir=config.data_cache_dir,
seed=seed + rank,
)
return DataLoader(
dataset,
batch_size=config.batch_size_per_gpu,
num_workers=config.num_workers,
pin_memory=True,
prefetch_factor=4,
)

View File

@@ -0,0 +1,144 @@
"""
DPO data pipeline: loads UltraFeedback preference pairs.
Each example has a prompt + chosen response + rejected response.
We tokenize both (prompt+chosen) and (prompt+rejected), apply the same
chat template, and return them as pairs for DPO training.
"""
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
CHAT_TEMPLATE = {
"user_start": "<|user|>\n",
"assistant_start": "<|assistant|>\n",
"turn_end": "\n<|end|>\n",
}
def format_preference_pair(prompt, chosen_msgs, rejected_msgs):
"""Build chat-templated strings for chosen and rejected."""
def build(messages):
text = CHAT_TEMPLATE["user_start"] + prompt.strip() + CHAT_TEMPLATE["turn_end"]
for msg in messages:
role = msg.get("role", "assistant")
content = msg.get("content", "").strip()
if role == "assistant":
text += CHAT_TEMPLATE["assistant_start"] + content + CHAT_TEMPLATE["turn_end"]
elif role == "user":
text += CHAT_TEMPLATE["user_start"] + content + CHAT_TEMPLATE["turn_end"]
return text
return build(chosen_msgs), build(rejected_msgs)
class DPODataset(Dataset):
"""
Loads UltraFeedback preference pairs and tokenizes them.
Returns (prompt_ids, chosen_ids, rejected_ids) with proper shifting.
"""
def __init__(self, tokenizer, max_seq_len=2048, split="train",
cache_dir=None, max_samples=None):
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
vocab = tokenizer.get_vocab()
new_tokens = [t for t in special_tokens if t not in vocab]
if new_tokens:
tokenizer.add_tokens(new_tokens, special_tokens=True)
self.assistant_token_id = tokenizer.encode("<|assistant|>", add_special_tokens=False)[0]
self.end_token_id = tokenizer.encode("<|end|>", add_special_tokens=False)[0]
self.user_token_id = tokenizer.encode("<|user|>", add_special_tokens=False)[0]
print(f"[DPO Data] Loading UltraFeedback preferences ({split})...")
ds = load_dataset(
"argilla/ultrafeedback-binarized-preferences-cleaned",
split=split,
cache_dir=cache_dir,
)
if max_samples:
ds = ds.select(range(min(max_samples, len(ds))))
print(f"[DPO Data] {len(ds)} preference pairs loaded")
self.examples = []
skipped = 0
for i, row in enumerate(ds):
prompt = row.get("prompt", "")
chosen = row.get("chosen", [])
rejected = row.get("rejected", [])
if not prompt or not chosen or not rejected:
skipped += 1
continue
chosen_text, rejected_text = format_preference_pair(prompt, chosen, rejected)
chosen_ids = tokenizer.encode(chosen_text, add_special_tokens=False)
rejected_ids = tokenizer.encode(rejected_text, add_special_tokens=False)
# Truncate if needed
if len(chosen_ids) > max_seq_len + 1:
chosen_ids = chosen_ids[:max_seq_len + 1]
if len(rejected_ids) > max_seq_len + 1:
rejected_ids = rejected_ids[:max_seq_len + 1]
if len(chosen_ids) < 10 or len(rejected_ids) < 10:
skipped += 1
continue
# Find where the prompt ends (first <|assistant|> token)
prompt_end = 0
for j, tid in enumerate(chosen_ids):
if tid == self.assistant_token_id:
prompt_end = j + 2 # skip <|assistant|> and \n
break
self.examples.append({
"chosen_ids": chosen_ids,
"rejected_ids": rejected_ids,
"prompt_len": prompt_end,
})
if (i + 1) % 20000 == 0:
print(f" Processed {i+1} pairs...")
print(f"[DPO Data] {len(self.examples)} pairs ready, {skipped} skipped")
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
ex = self.examples[idx]
return {
"chosen_ids": torch.tensor(ex["chosen_ids"], dtype=torch.long),
"rejected_ids": torch.tensor(ex["rejected_ids"], dtype=torch.long),
"prompt_len": ex["prompt_len"],
}
def dpo_collate_fn(batch, pad_id=0):
"""Pad chosen and rejected sequences separately."""
max_chosen = max(b["chosen_ids"].size(0) for b in batch)
max_rejected = max(b["rejected_ids"].size(0) for b in batch)
chosen_padded = []
rejected_padded = []
prompt_lens = []
for b in batch:
c_pad = max_chosen - b["chosen_ids"].size(0)
r_pad = max_rejected - b["rejected_ids"].size(0)
chosen_padded.append(torch.cat([b["chosen_ids"], torch.full((c_pad,), pad_id, dtype=torch.long)]))
rejected_padded.append(torch.cat([b["rejected_ids"], torch.full((r_pad,), pad_id, dtype=torch.long)]))
prompt_lens.append(b["prompt_len"])
return {
"chosen_ids": torch.stack(chosen_padded),
"rejected_ids": torch.stack(rejected_padded),
"prompt_lens": torch.tensor(prompt_lens, dtype=torch.long),
}

View File

@@ -0,0 +1,169 @@
"""
SFT data pipeline: loads UltraChat 200K and formats into chat template.
Chat template:
<|user|>
What is gravity?
<|end|>
<|assistant|>
Gravity is a fundamental force...
<|end|>
Labels are shifted left by 1 (standard causal LM), with user turns masked.
"""
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
CHAT_TEMPLATE = {
"user_start": "<|user|>\n",
"assistant_start": "<|assistant|>\n",
"turn_end": "\n<|end|>\n",
}
def format_conversation(messages):
"""Convert a list of {role, content} messages into our chat template string."""
text = ""
for msg in messages:
role = msg["role"]
content = msg["content"].strip()
if role == "user":
text += CHAT_TEMPLATE["user_start"] + content + CHAT_TEMPLATE["turn_end"]
elif role == "assistant":
text += CHAT_TEMPLATE["assistant_start"] + content + CHAT_TEMPLATE["turn_end"]
return text
class SFTDataset(Dataset):
"""
Loads UltraChat 200K conversations, tokenizes them, builds shifted labels
with user turns masked so the model only learns to generate assistant responses.
"""
def __init__(self, tokenizer, max_seq_len=2048, split="train_sft", cache_dir=None, max_samples=None):
self.tokenizer = tokenizer
self.max_seq_len = max_seq_len
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
vocab = tokenizer.get_vocab()
new_tokens = [t for t in special_tokens if t not in vocab]
if new_tokens:
tokenizer.add_tokens(new_tokens, special_tokens=True)
self.assistant_token_id = tokenizer.encode("<|assistant|>", add_special_tokens=False)[0]
self.end_token_id = tokenizer.encode("<|end|>", add_special_tokens=False)[0]
self.user_token_id = tokenizer.encode("<|user|>", add_special_tokens=False)[0]
print(f"[SFT Data] Loading UltraChat 200K ({split})...")
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=split, cache_dir=cache_dir)
if max_samples:
ds = ds.select(range(min(max_samples, len(ds))))
print(f"[SFT Data] {len(ds)} conversations loaded")
self.examples = []
skipped = 0
for i, row in enumerate(ds):
messages = row["messages"]
if len(messages) < 2:
skipped += 1
continue
text = format_conversation(messages)
all_ids = tokenizer.encode(text, add_special_tokens=False)
# Need at least max_seq_len+1 for shift, but truncate if longer
if len(all_ids) > max_seq_len + 1:
all_ids = all_ids[:max_seq_len + 1]
if len(all_ids) < 10:
skipped += 1
continue
# Shifted: input = all_ids[:-1], target = all_ids[1:]
input_ids = all_ids[:-1]
target_ids = all_ids[1:]
# Build mask: -100 for user turns, real token id for assistant turns
labels = self._build_shifted_labels(input_ids, target_ids)
self.examples.append((input_ids, labels))
if (i + 1) % 50000 == 0:
print(f" Processed {i+1} conversations...")
print(f"[SFT Data] {len(self.examples)} examples ready, {skipped} skipped")
def _build_shifted_labels(self, input_ids, target_ids):
"""
Walk through the token sequence and track whether we're in a user turn
or assistant turn. Only keep labels for assistant response content.
Masking strategy (applied to the SHIFTED target):
- Everything before and including <|assistant|>\\n: masked
- Assistant response content and <|end|>: TRAIN
- <|user|> and user content until next <|assistant|>: masked
"""
labels = [-100] * len(target_ids)
in_assistant = False
for i, tid in enumerate(input_ids):
if tid == self.assistant_token_id:
# Next token after <|assistant|> is \n, then content starts
in_assistant = True
continue
if tid == self.user_token_id:
in_assistant = False
continue
if in_assistant:
labels[i] = target_ids[i]
# When we hit <|end|> in assistant mode, include it then switch off
if tid == self.end_token_id and in_assistant:
in_assistant = False
return labels
def __len__(self):
return len(self.examples)
def __getitem__(self, idx):
input_ids, labels = self.examples[idx]
return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
def sft_collate_fn(batch, pad_id=0):
"""Pad sequences to the same length within a batch."""
input_ids_list, labels_list = zip(*batch)
max_len = max(ids.size(0) for ids in input_ids_list)
padded_inputs = []
padded_labels = []
for ids, lbl in zip(input_ids_list, labels_list):
pad_len = max_len - ids.size(0)
padded_inputs.append(torch.cat([ids, torch.full((pad_len,), pad_id, dtype=torch.long)]))
padded_labels.append(torch.cat([lbl, torch.full((pad_len,), -100, dtype=torch.long)]))
return torch.stack(padded_inputs), torch.stack(padded_labels)
def create_sft_dataloader(tokenizer, batch_size=4, max_seq_len=2048,
cache_dir=None, max_samples=None, num_workers=4):
dataset = SFTDataset(
tokenizer=tokenizer,
max_seq_len=max_seq_len,
split="train_sft",
cache_dir=cache_dir,
max_samples=max_samples,
)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True,
collate_fn=lambda b: sft_collate_fn(b, pad_id=tokenizer.pad_token_id),
), dataset

View File

@@ -0,0 +1,163 @@
"""
1B Parameter Decoder-Only Transformer — built from scratch.
Techniques:
- RoPE (Rotary Position Embeddings)
- Grouped Query Attention (GQA)
- SwiGLU Feed-Forward
- RMSNorm (pre-norm architecture)
- Flash Attention 2 (via PyTorch SDPA)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import ModelConfig
class RMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x: torch.Tensor) -> torch.Tensor:
norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
return (x.float() * norm).type_as(x) * self.weight
def precompute_rope_freqs(dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
t = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.outer(t, freqs)
return torch.polar(torch.ones_like(freqs), freqs) # complex64
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
B, S, H, D = xq.shape
xq_c = torch.view_as_complex(xq.float().reshape(B, S, H, D // 2, 2))
xk_c = torch.view_as_complex(xk.float().reshape(B, S, xk.shape[2], D // 2, 2))
freqs = freqs_cis[:S].clone().unsqueeze(0).unsqueeze(2)
xq_out = torch.view_as_real(xq_c * freqs).flatten(3)
xk_out = torch.view_as_real(xk_c * freqs).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class GroupedQueryAttention(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_kv_heads
self.head_dim = config.head_dim
self.num_groups = self.num_heads // self.num_kv_heads
self.wq = nn.Linear(config.hidden_dim, self.num_heads * self.head_dim, bias=False)
self.wk = nn.Linear(config.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
self.wv = nn.Linear(config.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
self.wo = nn.Linear(self.num_heads * self.head_dim, config.hidden_dim, bias=False)
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
B, S, _ = x.shape
q = self.wq(x).view(B, S, self.num_heads, self.head_dim)
k = self.wk(x).view(B, S, self.num_kv_heads, self.head_dim)
v = self.wv(x).view(B, S, self.num_kv_heads, self.head_dim)
q, k = apply_rope(q, k, freqs_cis)
# Expand KV heads for GQA
if self.num_groups > 1:
k = k.unsqueeze(3).expand(B, S, self.num_kv_heads, self.num_groups, self.head_dim)
k = k.reshape(B, S, self.num_heads, self.head_dim)
v = v.unsqueeze(3).expand(B, S, self.num_kv_heads, self.num_groups, self.head_dim)
v = v.reshape(B, S, self.num_heads, self.head_dim)
# (B, num_heads, S, head_dim) for SDPA
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
out = out.transpose(1, 2).contiguous().view(B, S, -1)
return self.wo(out)
class SwiGLUFFN(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.w_gate = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
self.w_up = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
self.w_down = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
class TransformerBlock(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.attention_norm = RMSNorm(config.hidden_dim, eps=config.rms_norm_eps)
self.attention = GroupedQueryAttention(config)
self.ffn_norm = RMSNorm(config.hidden_dim, eps=config.rms_norm_eps)
self.ffn = SwiGLUFFN(config)
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
x = x + self.attention(self.attention_norm(x), freqs_cis)
x = x + self.ffn(self.ffn_norm(x))
return x
class Transformer(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config
self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_dim)
self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
self.norm = RMSNorm(config.hidden_dim, eps=config.rms_norm_eps)
self.output = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
# Pre-compute RoPE frequencies
self.register_buffer(
"freqs_cis",
precompute_rope_freqs(config.head_dim, config.max_seq_len * 2, config.rope_theta),
persistent=False,
)
self._init_weights()
def _init_weights(self):
"""Initialize with scaled normal, following GPT-NeoX / LLaMA conventions."""
for module in self.modules():
if isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
nn.init.normal_(module.weight, mean=0.0, std=0.02)
# Scale residual projections by 1/sqrt(2*num_layers)
scale = (2 * self.config.num_layers) ** -0.5
for layer in self.layers:
nn.init.normal_(layer.attention.wo.weight, mean=0.0, std=0.02 * scale)
nn.init.normal_(layer.ffn.w_down.weight, mean=0.0, std=0.02 * scale)
def forward(self, tokens: torch.Tensor, targets: torch.Tensor = None):
B, S = tokens.shape
h = self.tok_embeddings(tokens)
freqs_cis = self.freqs_cis[:S]
for layer in self.layers:
h = layer(h, freqs_cis)
h = self.norm(h)
logits = self.output(h)
loss = None
if targets is not None:
loss = F.cross_entropy(
logits.view(-1, logits.size(-1)),
targets.view(-1),
ignore_index=-100,
)
return logits, loss

148
training_code/test_sft.py Normal file
View File

@@ -0,0 +1,148 @@
"""Quick test of model quality with diverse prompts."""
import os, sys, time, torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer
DPO_CKPT = "/jfs/deepak-kumar/checkpoints_dpo/dpo_final.pt"
SFT_CKPT = "/jfs/deepak-kumar/checkpoints_sft/sft_final.pt"
CHECKPOINT = DPO_CKPT if os.path.exists(DPO_CKPT) else SFT_CKPT
DEVICE = "cuda:0"
USER_START = "<|user|>\n"
ASST_START = "<|assistant|>\n"
TURN_END = "\n<|end|>\n"
TEST_PROMPTS = [
"Hi! How are you?",
"What is photosynthesis?",
"Explain gravity to a 5-year-old.",
"Write a short poem about the ocean.",
"What are the three states of matter?",
"How does a computer work?",
"What is the capital of France and why is it famous?",
"Give me 3 tips for learning a new language.",
"What is machine learning in simple terms?",
]
@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=256,
temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.15):
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
input_ids = torch.tensor([input_ids], dtype=torch.long, device=DEVICE)
generated = []
eos_id = tokenizer.eos_token_id
end_token_ids = tokenizer.encode("<|end|>", add_special_tokens=False)
end_id = end_token_ids[0] if end_token_ids else None
user_token_ids = tokenizer.encode("<|user|>", add_special_tokens=False)
user_id = user_token_ids[0] if user_token_ids else None
stop_ids = set()
if eos_id is not None:
stop_ids.add(eos_id)
if end_id is not None:
stop_ids.add(end_id)
if user_id is not None:
stop_ids.add(user_id)
for _ in range(max_new_tokens):
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
logits, _ = model(input_ids)
logits = logits[:, -1, :].float()
if repetition_penalty != 1.0 and generated:
for tid in set(generated):
if logits[0, tid] > 0:
logits[0, tid] /= repetition_penalty
else:
logits[0, tid] *= repetition_penalty
logits = logits / max(temperature, 1e-5)
if top_k > 0:
topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
logits[logits < topk_vals[:, -1:]] = float('-inf')
if top_p < 1.0:
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
remove = cumulative - torch.softmax(sorted_logits, dim=-1) > top_p
sorted_logits[remove] = float('-inf')
logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, 1)
token_id = next_token.item()
if token_id in stop_ids:
break
generated.append(token_id)
input_ids = torch.cat([input_ids, next_token], dim=1)
if input_ids.size(1) > 2048:
break
return tokenizer.decode(generated, skip_special_tokens=True)
def main():
ckpt_name = "DPO" if "dpo" in CHECKPOINT else "SFT"
print("=" * 70)
print(" " + ckpt_name + " MODEL TEST")
print("=" * 70)
tokenizer = get_tokenizer()
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
vocab = tokenizer.get_vocab()
new_tokens = [t for t in special_tokens if t not in vocab]
if new_tokens:
tokenizer.add_tokens(new_tokens, special_tokens=True)
config = ModelConfig()
config.vocab_size = len(tokenizer)
model = Transformer(config)
print("")
print("Loading checkpoint: " + CHECKPOINT)
ckpt = torch.load(CHECKPOINT, map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["model"])
step = ckpt.get("step", "?")
del ckpt
model = model.to(DEVICE).bfloat16().eval()
print("Model loaded (" + ckpt_name + " step " + str(step) + ", vocab " + str(config.vocab_size) + ")")
mem = torch.cuda.max_memory_allocated(DEVICE) / 1e9
print("GPU memory: " + str(round(mem, 1)) + " GB")
print("-" * 70)
for i, question in enumerate(TEST_PROMPTS, 1):
prompt = USER_START + question + TURN_END + ASST_START
print("")
print("[Test " + str(i) + "/" + str(len(TEST_PROMPTS)) + "]")
print(" Q: " + question)
t0 = time.time()
response = generate(model, tokenizer, prompt)
dt = time.time() - t0
tokens = len(tokenizer.encode(response, add_special_tokens=False))
response = response.split("<|end|>")[0].split("<|user|>")[0].strip()
print(" A: " + response)
tps = int(tokens / max(dt, 0.01))
print(" [" + str(tokens) + " tokens, " + str(round(dt, 1)) + "s, " + str(tps) + " tok/s]")
print("-" * 70)
print("")
print("Done!")
if __name__ == "__main__":
main()

257
training_code/train.py Normal file
View File

@@ -0,0 +1,257 @@
"""
Distributed training script for 1B parameter Transformer.
Launch: torchrun --nproc_per_node=8 train.py
Stack: PyTorch DDP + BF16 autocast + 8x H100 80GB
"""
import os
import sys
import math
import time
import json
import datetime
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import ModelConfig, TrainConfig
from model.transformer import Transformer
from model.data import get_tokenizer, create_dataloader
def get_wsd_lr(step, warmup_steps, total_steps, max_lr, min_lr):
"""Warmup-Stable-Decay: linear warmup -> constant -> cosine decay (last 20%)."""
stable_end = int(total_steps * 0.8)
if step < warmup_steps:
return max_lr * step / max(warmup_steps, 1)
elif step < stable_end:
return max_lr
else:
progress = (step - stable_end) / max(total_steps - stable_end, 1)
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
def find_latest_checkpoint(checkpoint_dir):
"""Find the latest step_*.pt checkpoint in the directory."""
import glob
pattern = os.path.join(checkpoint_dir, "step_*.pt")
files = glob.glob(pattern)
if not files:
return None, 0
latest = max(files, key=lambda f: int(os.path.basename(f).replace("step_", "").replace(".pt", "")))
step = int(os.path.basename(latest).replace("step_", "").replace(".pt", ""))
return latest, step
def main():
dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30))
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")
model_config = ModelConfig()
train_config = TrainConfig()
eff_batch = train_config.batch_size_per_gpu * world_size * train_config.gradient_accumulation_steps
tokens_per_step = eff_batch * model_config.max_seq_len
total_steps = train_config.total_tokens // tokens_per_step
if rank == 0:
os.makedirs(train_config.log_dir, exist_ok=True)
os.makedirs(train_config.checkpoint_dir, exist_ok=True)
print("=" * 70)
print(f" TRAINING 1B TRANSFORMER FROM SCRATCH")
print(f" Arch: {model_config.num_layers}L / {model_config.hidden_dim}D / "
f"{model_config.num_attention_heads}H / GQA-{model_config.num_kv_heads}KV / "
f"SwiGLU-{model_config.intermediate_dim}")
print(f" Seq: {model_config.max_seq_len} | Vocab: {model_config.vocab_size}")
print(f" GPUs: {world_size}x H100 80GB | Backend: DDP + BF16 autocast")
print(f" Batch: {eff_batch} seqs = {tokens_per_step:,} tok/step")
print(f" Steps: {total_steps:,} | Target: {train_config.total_tokens:,} tokens")
print("=" * 70)
# Tokenizer
tokenizer = get_tokenizer()
# Model
torch.manual_seed(train_config.seed)
model = Transformer(model_config).to(device)
if rank == 0:
n = sum(p.numel() for p in model.parameters())
print(f"[Init] Params: {n:,} ({n/1e9:.3f}B)")
model = DDP(model, device_ids=[local_rank])
# Optimizer
decay_params = [p for n, p in model.named_parameters() if p.dim() >= 2 and p.requires_grad]
nodecay_params = [p for n, p in model.named_parameters() if p.dim() < 2 and p.requires_grad]
optimizer = torch.optim.AdamW([
{"params": decay_params, "weight_decay": train_config.weight_decay},
{"params": nodecay_params, "weight_decay": 0.0},
], lr=train_config.learning_rate, betas=(train_config.beta1, train_config.beta2), fused=True)
if rank == 0:
dp = sum(p.numel() for p in decay_params)
ndp = sum(p.numel() for p in nodecay_params)
print(f"[Init] Optimizer: {dp:,} decay + {ndp:,} no-decay params")
# Resume from checkpoint
resume_step = 0
ckpt_path, ckpt_step = find_latest_checkpoint(train_config.checkpoint_dir)
if ckpt_path is not None:
if rank == 0:
print(f"[Resume] Loading checkpoint: {ckpt_path} (step {ckpt_step})")
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
model.module.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
resume_step = ckpt["step"]
if rank == 0:
print(f"[Resume] Restored model + optimizer at step {resume_step}, "
f"loss was {ckpt.get('loss', 'N/A')}")
del ckpt
torch.cuda.empty_cache()
else:
if rank == 0:
print("[Init] No checkpoint found, starting from scratch")
# Data — use (seed + resume_step) so resumed runs see different shuffled data
effective_seed = train_config.seed + resume_step
dataloader = create_dataloader(tokenizer, train_config, rank=rank, world_size=world_size,
seed_override=effective_seed)
data_iter = iter(dataloader)
if rank == 0:
print(f"[Init] Dataloader ready (streaming FineWeb-Edu 10BT)")
print(f"[Schedule] WSD: warmup {train_config.warmup_steps} -> "
f"stable {int(total_steps*0.8)} -> decay {total_steps}")
if resume_step > 0:
remaining = total_steps - resume_step
print(f"[Resume] Continuing from step {resume_step}, {remaining:,} steps remaining")
print("-" * 70)
sys.stdout.flush()
# ===== TRAINING LOOP =====
model.train()
global_step = resume_step
running_loss = 0.0
best_loss = float("inf")
tokens_done = resume_step * tokens_per_step
t0 = time.time()
step_t0 = time.time()
log_file = open(os.path.join(train_config.log_dir, "train_log.jsonl"), "a") if rank == 0 else None
while global_step < total_steps:
optimizer.zero_grad(set_to_none=True)
micro_loss = 0.0
for micro in range(train_config.gradient_accumulation_steps):
try:
input_ids, labels = next(data_iter)
except StopIteration:
data_iter = iter(dataloader)
input_ids, labels = next(data_iter)
input_ids = input_ids.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)
# BF16 autocast — no scaler needed (BF16 has enough dynamic range)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
_, loss = model(input_ids, labels)
loss = loss / train_config.gradient_accumulation_steps
loss.backward()
micro_loss += loss.item()
# Gradient clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.grad_clip)
# LR schedule
lr = get_wsd_lr(global_step, train_config.warmup_steps, total_steps,
train_config.learning_rate, train_config.min_lr)
for pg in optimizer.param_groups:
pg["lr"] = lr
optimizer.step()
global_step += 1
running_loss += micro_loss
tokens_done += tokens_per_step
# Log
if global_step % train_config.log_interval == 0:
dt = time.time() - step_t0
tps = (train_config.log_interval * tokens_per_step) / max(dt, 1e-9)
avg = running_loss / train_config.log_interval
elapsed = time.time() - t0
pct = 100.0 * global_step / total_steps
eta = (elapsed / max(global_step, 1)) * (total_steps - global_step)
if rank == 0:
gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
print(
f"[Step {global_step:>6d}/{total_steps}] "
f"loss={avg:.4f} | lr={lr:.2e} | "
f"tok/s={tps:,.0f} | GPU={gpu_mem:.1f}GB | "
f"{pct:.1f}% | ETA={eta/3600:.1f}h",
flush=True,
)
if log_file:
log_file.write(json.dumps({
"step": global_step, "loss": round(avg, 4), "lr": lr,
"tps": round(tps), "tokens": tokens_done,
"gpu_gb": round(gpu_mem, 1), "elapsed_s": round(elapsed, 1),
}) + "\n")
log_file.flush()
if avg < best_loss:
best_loss = avg
running_loss = 0.0
step_t0 = time.time()
# Checkpoint
if global_step % train_config.save_interval == 0:
dist.barrier()
if rank == 0:
ckpt_path = os.path.join(train_config.checkpoint_dir, f"step_{global_step}.pt")
torch.save({
"step": global_step,
"model": model.module.state_dict(),
"optimizer": optimizer.state_dict(),
"loss": avg if global_step % train_config.log_interval == 0 else micro_loss,
"config": {"model": model_config.__dict__, "train": train_config.__dict__},
}, ckpt_path)
print(f" >> Checkpoint: {ckpt_path}", flush=True)
dist.barrier()
# Final
dist.barrier()
if rank == 0:
final_path = os.path.join(train_config.checkpoint_dir, "final.pt")
torch.save({
"step": global_step,
"model": model.module.state_dict(),
"config": {"model": model_config.__dict__, "train": train_config.__dict__},
}, final_path)
total_time = time.time() - t0
print("=" * 70)
print(f" TRAINING COMPLETE")
print(f" Steps: {global_step:,} | Tokens: {tokens_done:,}")
print(f" Time: {total_time/3600:.2f}h | Throughput: {tokens_done/total_time:,.0f} tok/s")
print(f" Best loss: {best_loss:.4f}")
print(f" Final model: {final_path}")
print("=" * 70)
if log_file:
log_file.close()
dist.destroy_process_group()
if __name__ == "__main__":
main()

327
training_code/train_dpo.py Normal file
View File

@@ -0,0 +1,327 @@
"""
DPO (Direct Preference Optimization) training for the 1B Transformer.
Takes the SFT model and aligns it with human preferences using
UltraFeedback preference pairs.
DPO Loss:
L = -log sigma(beta * (log pi(yw|x)/pi_ref(yw|x) - log pi(yl|x)/pi_ref(yl|x)))
Launch: torchrun --nproc_per_node=8 train_dpo.py
"""
import os
import sys
import math
import time
import json
import datetime
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer
from model.dpo_data import DPODataset, dpo_collate_fn
# === Config ===
SFT_CHECKPOINT = "/jfs/deepak-kumar/checkpoints_sft/sft_final.pt"
DPO_CHECKPOINT_DIR = "/jfs/deepak-kumar/checkpoints_dpo"
LOG_DIR = "/home/jovyan/training/logs"
DATA_CACHE = "/jfs/deepak-kumar/data"
NUM_EPOCHS = 1
BATCH_SIZE_PER_GPU = 2
GRADIENT_ACCUMULATION = 4 # effective batch = 2 * 8 * 4 = 64
MAX_SEQ_LEN = 1024
LEARNING_RATE = 5e-7 # very low LR for DPO
MIN_LR = 1e-7
WARMUP_STEPS = 100
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0
BETA = 0.1 # DPO temperature
LOG_INTERVAL = 10
SAVE_INTERVAL = 200
def get_cosine_lr(step, warmup_steps, total_steps, max_lr, min_lr):
if step < warmup_steps:
return max_lr * step / max(warmup_steps, 1)
progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
def get_per_token_logps(model, input_ids, prompt_lens):
"""
Compute sum of log probabilities for response tokens only.
input_ids: [B, S] full sequence (prompt + response)
prompt_lens: [B] where response starts
Returns: [B] sum of log probs over response tokens
"""
# Clone input to avoid inplace issues with shared RoPE buffers
inp = input_ids[:, :-1].contiguous()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
logits, _ = model(inp)
labels = input_ids[:, 1:].contiguous()
log_probs = F.log_softmax(logits.float(), dim=-1)
token_logps = log_probs.gather(2, labels.unsqueeze(2)).squeeze(2)
B, S = token_logps.shape
mask = torch.zeros_like(token_logps)
for b in range(B):
pl = prompt_lens[b].item()
response_start = max(0, pl - 1)
seq_len = (labels[b] != 0).sum().item()
mask[b, response_start:seq_len] = 1.0
return (token_logps * mask).sum(dim=1)
def dpo_loss(policy_chosen_logps, policy_rejected_logps,
ref_chosen_logps, ref_rejected_logps, beta=0.1):
"""Compute DPO loss and metrics."""
chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
logits = chosen_rewards - rejected_rewards
loss = -F.logsigmoid(logits).mean()
with torch.no_grad():
chosen_better = (chosen_rewards > rejected_rewards).float().mean()
reward_margin = (chosen_rewards - rejected_rewards).mean()
return loss, chosen_better.item(), reward_margin.item()
def main():
dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30))
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")
if rank == 0:
os.makedirs(DPO_CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
print("=" * 70)
print(" DPO: PREFERENCE ALIGNMENT FOR 1B TRANSFORMER")
print("=" * 70)
tokenizer = get_tokenizer()
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
vocab = tokenizer.get_vocab()
new_tokens = [t for t in special_tokens if t not in vocab]
if new_tokens:
tokenizer.add_tokens(new_tokens, special_tokens=True)
model_config = ModelConfig()
model_config.vocab_size = len(tokenizer)
if rank == 0:
print(f"[Init] Loading SFT model from {SFT_CHECKPOINT}")
# Policy model (trainable)
policy = Transformer(model_config)
ckpt = torch.load(SFT_CHECKPOINT, map_location="cpu", weights_only=False)
policy.load_state_dict(ckpt["model"])
sft_step = ckpt.get("step", 0)
if rank == 0:
print(f"[Init] SFT model loaded (step {sft_step})")
# Reference model (frozen copy)
ref_model = Transformer(model_config)
ref_model.load_state_dict(ckpt["model"])
del ckpt
policy = policy.to(device)
ref_model = ref_model.to(device).bfloat16()
ref_model.eval()
for p in ref_model.parameters():
p.requires_grad = False
policy = DDP(policy, device_ids=[local_rank])
if rank == 0:
n = sum(p.numel() for p in policy.parameters())
print(f"[Init] Params: {n:,} | GPUs: {world_size}x H100")
print(f"[Init] Beta: {BETA} | LR: {LEARNING_RATE}")
# Dataset
dataset = DPODataset(
tokenizer=tokenizer,
max_seq_len=MAX_SEQ_LEN,
split="train",
cache_dir=DATA_CACHE,
)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=BATCH_SIZE_PER_GPU,
sampler=sampler,
num_workers=4,
pin_memory=True,
collate_fn=lambda b: dpo_collate_fn(b, pad_id=tokenizer.pad_token_id),
)
steps_per_epoch = len(dataloader) // GRADIENT_ACCUMULATION
total_steps = steps_per_epoch * NUM_EPOCHS
if rank == 0:
eff_batch = BATCH_SIZE_PER_GPU * world_size * GRADIENT_ACCUMULATION
print(f"[Init] Dataset: {len(dataset):,} preference pairs")
print(f"[Init] Effective batch: {eff_batch} | Steps/epoch: {steps_per_epoch}")
print(f"[Init] Total steps: {total_steps}")
print("-" * 70)
decay_params = [p for n, p in policy.named_parameters() if p.dim() >= 2 and p.requires_grad]
nodecay_params = [p for n, p in policy.named_parameters() if p.dim() < 2 and p.requires_grad]
optimizer = torch.optim.AdamW([
{"params": decay_params, "weight_decay": WEIGHT_DECAY},
{"params": nodecay_params, "weight_decay": 0.0},
], lr=LEARNING_RATE, betas=(0.9, 0.95), fused=True)
policy.train()
global_step = 0
running_loss = 0.0
running_acc = 0.0
running_margin = 0.0
t0 = time.time()
log_file = open(os.path.join(LOG_DIR, "dpo_log.jsonl"), "w") if rank == 0 else None
for epoch in range(NUM_EPOCHS):
sampler.set_epoch(epoch)
data_iter = iter(dataloader)
if rank == 0:
print(f"\n[Epoch {epoch + 1}/{NUM_EPOCHS}]")
while True:
optimizer.zero_grad(set_to_none=True)
batch_loss = 0.0
batch_acc = 0.0
batch_margin = 0.0
valid_micros = 0
for _ in range(GRADIENT_ACCUMULATION):
try:
batch = next(data_iter)
except StopIteration:
break
chosen_ids = batch["chosen_ids"].to(device, non_blocking=True)
rejected_ids = batch["rejected_ids"].to(device, non_blocking=True)
prompt_lens = batch["prompt_lens"].to(device, non_blocking=True)
policy_chosen_logps = get_per_token_logps(policy, chosen_ids, prompt_lens)
policy_rejected_logps = get_per_token_logps(policy, rejected_ids, prompt_lens)
with torch.no_grad():
ref_chosen_logps = get_per_token_logps(ref_model, chosen_ids, prompt_lens)
ref_rejected_logps = get_per_token_logps(ref_model, rejected_ids, prompt_lens)
loss, acc, margin = dpo_loss(
policy_chosen_logps, policy_rejected_logps,
ref_chosen_logps, ref_rejected_logps,
beta=BETA,
)
loss = loss / GRADIENT_ACCUMULATION
loss.backward()
batch_loss += loss.item()
batch_acc += acc
batch_margin += margin
valid_micros += 1
if valid_micros == 0:
break
torch.nn.utils.clip_grad_norm_(policy.parameters(), GRAD_CLIP)
lr = get_cosine_lr(global_step, WARMUP_STEPS, total_steps, LEARNING_RATE, MIN_LR)
for pg in optimizer.param_groups:
pg["lr"] = lr
optimizer.step()
global_step += 1
running_loss += batch_loss
running_acc += batch_acc / valid_micros
running_margin += batch_margin / valid_micros
if global_step % LOG_INTERVAL == 0:
avg_loss = running_loss / LOG_INTERVAL
avg_acc = running_acc / LOG_INTERVAL
avg_margin = running_margin / LOG_INTERVAL
elapsed = time.time() - t0
pct = 100.0 * global_step / total_steps
eta = (elapsed / max(global_step, 1)) * (total_steps - global_step)
if rank == 0:
gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
print(
f" [Step {global_step:>5d}/{total_steps}] "
f"loss={avg_loss:.4f} | acc={avg_acc:.1%} | "
f"margin={avg_margin:.3f} | lr={lr:.2e} | "
f"GPU={gpu_mem:.1f}GB | {pct:.1f}% | ETA={eta/60:.0f}m",
flush=True,
)
if log_file:
log_file.write(json.dumps({
"step": global_step, "loss": round(avg_loss, 4),
"accuracy": round(avg_acc, 4),
"reward_margin": round(avg_margin, 4),
"lr": lr, "elapsed_s": round(elapsed, 1),
}) + "\n")
log_file.flush()
running_loss = 0.0
running_acc = 0.0
running_margin = 0.0
if global_step % SAVE_INTERVAL == 0:
dist.barrier()
if rank == 0:
path = os.path.join(DPO_CHECKPOINT_DIR, f"dpo_step_{global_step}.pt")
torch.save({
"step": global_step,
"model": policy.module.state_dict(),
"config": model_config.__dict__,
"vocab_size": model_config.vocab_size,
}, path)
print(f" >> Checkpoint: {path}", flush=True)
dist.barrier()
# Final save
dist.barrier()
if rank == 0:
final_path = os.path.join(DPO_CHECKPOINT_DIR, "dpo_final.pt")
torch.save({
"step": global_step,
"model": policy.module.state_dict(),
"config": model_config.__dict__,
"vocab_size": model_config.vocab_size,
}, final_path)
total_time = time.time() - t0
print("=" * 70)
print(f" DPO COMPLETE")
print(f" Steps: {global_step:,} | Epochs: {NUM_EPOCHS}")
print(f" Time: {total_time/60:.1f} minutes")
print(f" Final model: {final_path}")
print("=" * 70)
if log_file:
log_file.close()
dist.destroy_process_group()
if __name__ == "__main__":
main()

272
training_code/train_sft.py Normal file
View File

@@ -0,0 +1,272 @@
"""
SFT (Supervised Fine-Tuning) script for the 1B Transformer.
Takes the pretrained base model and fine-tunes it on instruction-response
conversations from UltraChat 200K.
Launch: torchrun --nproc_per_node=8 train_sft.py
"""
import os
import sys
import math
import time
import json
import datetime
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer
from model.sft_data import SFTDataset, sft_collate_fn
# === Config ===
BASE_CHECKPOINT = "/jfs/deepak-kumar/checkpoints/step_19000.pt"
SFT_CHECKPOINT_DIR = "/jfs/deepak-kumar/checkpoints_sft"
LOG_DIR = "/home/jovyan/training/logs"
DATA_CACHE = "/jfs/deepak-kumar/data"
NUM_EPOCHS = 2
BATCH_SIZE_PER_GPU = 4
GRADIENT_ACCUMULATION = 4 # effective batch = 4 * 8 * 4 = 128
MAX_SEQ_LEN = 2048
LEARNING_RATE = 2e-5 # much lower than pretraining — we're fine-tuning
MIN_LR = 2e-6
WARMUP_STEPS = 200
WEIGHT_DECAY = 0.01
GRAD_CLIP = 1.0
LOG_INTERVAL = 10
SAVE_INTERVAL = 500
def get_cosine_lr(step, warmup_steps, total_steps, max_lr, min_lr):
if step < warmup_steps:
return max_lr * step / max(warmup_steps, 1)
progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
def main():
dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30))
rank = int(os.environ.get("RANK", 0))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")
if rank == 0:
os.makedirs(SFT_CHECKPOINT_DIR, exist_ok=True)
os.makedirs(LOG_DIR, exist_ok=True)
print("=" * 70)
print(" SFT: INSTRUCTION FINE-TUNING 1B TRANSFORMER")
print("=" * 70)
# Tokenizer
tokenizer = get_tokenizer()
# Load base model
model_config = ModelConfig()
torch.manual_seed(42)
model = Transformer(model_config)
if rank == 0:
print(f"[Init] Loading base model from {BASE_CHECKPOINT}")
ckpt = torch.load(BASE_CHECKPOINT, map_location="cpu", weights_only=False)
model.load_state_dict(ckpt["model"])
base_step = ckpt.get("step", 0)
base_loss = ckpt.get("loss", "?")
if rank == 0:
print(f"[Init] Base model: step={base_step}, pretrain_loss={base_loss}")
del ckpt
# Add chat tokens to embedding — expand vocab if needed
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
vocab = tokenizer.get_vocab()
new_tokens = [t for t in special_tokens if t not in vocab]
if new_tokens:
tokenizer.add_tokens(new_tokens, special_tokens=True)
new_vocab_size = len(tokenizer)
if new_vocab_size > model_config.vocab_size:
if rank == 0:
print(f"[Init] Expanding vocab: {model_config.vocab_size} -> {new_vocab_size}")
old_emb_weight = model.tok_embeddings.weight.data
model.tok_embeddings = torch.nn.Embedding(new_vocab_size, model_config.hidden_dim)
model.tok_embeddings.weight.data[:model_config.vocab_size] = old_emb_weight
# Init new token embeddings as mean of existing (better than random)
mean_emb = old_emb_weight.mean(dim=0)
for i in range(model_config.vocab_size, new_vocab_size):
model.tok_embeddings.weight.data[i] = mean_emb
old_output_weight = model.output.weight.data
model.output = torch.nn.Linear(model_config.hidden_dim, new_vocab_size, bias=False)
model.output.weight.data[:model_config.vocab_size] = old_output_weight
model.config.vocab_size = new_vocab_size
model = model.to(device)
model = DDP(model, device_ids=[local_rank])
if rank == 0:
n = sum(p.numel() for p in model.parameters())
print(f"[Init] Params: {n:,} | GPUs: {world_size}x H100")
# Dataset (only load on each process)
dataset = SFTDataset(
tokenizer=tokenizer,
max_seq_len=MAX_SEQ_LEN,
split="train_sft",
cache_dir=DATA_CACHE,
)
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=BATCH_SIZE_PER_GPU,
sampler=sampler,
num_workers=4,
pin_memory=True,
collate_fn=lambda b: sft_collate_fn(b, pad_id=tokenizer.pad_token_id),
)
steps_per_epoch = len(dataloader) // GRADIENT_ACCUMULATION
total_steps = steps_per_epoch * NUM_EPOCHS
if rank == 0:
eff_batch = BATCH_SIZE_PER_GPU * world_size * GRADIENT_ACCUMULATION
print(f"[Init] Dataset: {len(dataset):,} examples")
print(f"[Init] Effective batch: {eff_batch} | Steps/epoch: {steps_per_epoch}")
print(f"[Init] Total steps: {total_steps} | Epochs: {NUM_EPOCHS}")
print(f"[Init] LR: {LEARNING_RATE}{MIN_LR} (cosine)")
print("-" * 70)
# Optimizer — lower LR for fine-tuning
decay_params = [p for n, p in model.named_parameters() if p.dim() >= 2 and p.requires_grad]
nodecay_params = [p for n, p in model.named_parameters() if p.dim() < 2 and p.requires_grad]
optimizer = torch.optim.AdamW([
{"params": decay_params, "weight_decay": WEIGHT_DECAY},
{"params": nodecay_params, "weight_decay": 0.0},
], lr=LEARNING_RATE, betas=(0.9, 0.95), fused=True)
# Training
model.train()
global_step = 0
running_loss = 0.0
t0 = time.time()
step_t0 = time.time()
log_file = open(os.path.join(LOG_DIR, "sft_log.jsonl"), "w") if rank == 0 else None
for epoch in range(NUM_EPOCHS):
sampler.set_epoch(epoch)
data_iter = iter(dataloader)
micro_step = 0
if rank == 0:
print(f"\n[Epoch {epoch + 1}/{NUM_EPOCHS}]")
while True:
optimizer.zero_grad(set_to_none=True)
batch_loss = 0.0
for _ in range(GRADIENT_ACCUMULATION):
try:
input_ids, labels = next(data_iter)
except StopIteration:
break
input_ids = input_ids.to(device, non_blocking=True)
labels = labels.to(device, non_blocking=True)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
_, loss = model(input_ids, labels)
loss = loss / GRADIENT_ACCUMULATION
loss.backward()
batch_loss += loss.item()
micro_step += 1
if batch_loss == 0:
break
torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
lr = get_cosine_lr(global_step, WARMUP_STEPS, total_steps, LEARNING_RATE, MIN_LR)
for pg in optimizer.param_groups:
pg["lr"] = lr
optimizer.step()
global_step += 1
running_loss += batch_loss
if global_step % LOG_INTERVAL == 0:
dt = time.time() - step_t0
avg = running_loss / LOG_INTERVAL
elapsed = time.time() - t0
pct = 100.0 * global_step / total_steps
if rank == 0:
gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
eta = (elapsed / max(global_step, 1)) * (total_steps - global_step)
print(
f" [Step {global_step:>5d}/{total_steps}] "
f"loss={avg:.4f} | lr={lr:.2e} | "
f"GPU={gpu_mem:.1f}GB | {pct:.1f}% | ETA={eta/60:.0f}m",
flush=True,
)
if log_file:
log_file.write(json.dumps({
"step": global_step, "epoch": epoch + 1,
"loss": round(avg, 4), "lr": lr,
"elapsed_s": round(elapsed, 1),
}) + "\n")
log_file.flush()
running_loss = 0.0
step_t0 = time.time()
if global_step % SAVE_INTERVAL == 0:
dist.barrier()
if rank == 0:
path = os.path.join(SFT_CHECKPOINT_DIR, f"sft_step_{global_step}.pt")
torch.save({
"step": global_step,
"model": model.module.state_dict(),
"config": model_config.__dict__,
"vocab_size": new_vocab_size,
}, path)
print(f" >> Checkpoint: {path}", flush=True)
dist.barrier()
# Final save
dist.barrier()
if rank == 0:
final_path = os.path.join(SFT_CHECKPOINT_DIR, "sft_final.pt")
torch.save({
"step": global_step,
"model": model.module.state_dict(),
"config": model_config.__dict__,
"vocab_size": new_vocab_size,
}, final_path)
total_time = time.time() - t0
print("=" * 70)
print(f" SFT COMPLETE")
print(f" Steps: {global_step:,} | Epochs: {NUM_EPOCHS}")
print(f" Time: {total_time/60:.1f} minutes")
print(f" Final model: {final_path}")
print("=" * 70)
if log_file:
log_file.close()
dist.destroy_process_group()
if __name__ == "__main__":
main()