初始化项目,由ModelHub XC社区提供模型
Model: divakar-yadav/transformer-1b-chat Source: Original Platform
This commit is contained in:
318
training_code/chat.py
Normal file
318
training_code/chat.py
Normal file
@@ -0,0 +1,318 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Interactive chat with the 1B Transformer.
|
||||
Runs in an infinite conversation loop from the terminal.
|
||||
|
||||
Usage:
|
||||
python chat.py # auto-find latest checkpoint
|
||||
python chat.py /jfs/deepak-kumar/checkpoints/step_19000.pt # specific checkpoint
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import glob
|
||||
import time
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import readline # enables arrow keys and history in input()
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer
|
||||
|
||||
|
||||
def find_latest_checkpoint():
|
||||
"""Look for DPO > SFT > pretrained checkpoint."""
|
||||
dpo_dir = "/jfs/deepak-kumar/checkpoints_dpo"
|
||||
sft_dir = "/jfs/deepak-kumar/checkpoints_sft"
|
||||
pt_dir = "/jfs/deepak-kumar/checkpoints"
|
||||
|
||||
# Prefer DPO final
|
||||
dpo_final = os.path.join(dpo_dir, "dpo_final.pt")
|
||||
if os.path.exists(dpo_final):
|
||||
return dpo_final, True
|
||||
|
||||
dpo_files = glob.glob(os.path.join(dpo_dir, "dpo_step_*.pt"))
|
||||
if dpo_files:
|
||||
return max(dpo_files, key=lambda f: int(f.split("dpo_step_")[1].split(".")[0])), True
|
||||
|
||||
# Then SFT
|
||||
sft_final = os.path.join(sft_dir, "sft_final.pt")
|
||||
if os.path.exists(sft_final):
|
||||
return sft_final, True
|
||||
|
||||
sft_files = glob.glob(os.path.join(sft_dir, "sft_step_*.pt"))
|
||||
if sft_files:
|
||||
return max(sft_files, key=lambda f: int(f.split("sft_step_")[1].split(".")[0])), True
|
||||
|
||||
# Fall back to pretrained
|
||||
pt_files = glob.glob(os.path.join(pt_dir, "step_*.pt"))
|
||||
if pt_files:
|
||||
return max(pt_files, key=lambda f: int(os.path.basename(f).split("_")[1].split(".")[0])), False
|
||||
|
||||
return None, False
|
||||
|
||||
|
||||
def load_model(checkpoint_path, tokenizer, device="cuda:0"):
|
||||
config = ModelConfig()
|
||||
model = Transformer(config)
|
||||
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
||||
|
||||
# Handle expanded vocab from SFT
|
||||
saved_vocab = ckpt.get("vocab_size", config.vocab_size)
|
||||
if saved_vocab > config.vocab_size:
|
||||
config.vocab_size = saved_vocab
|
||||
model = Transformer(config)
|
||||
|
||||
model.load_state_dict(ckpt["model"])
|
||||
model = model.to(device).bfloat16().eval()
|
||||
step = ckpt.get("step", "?")
|
||||
loss = ckpt.get("loss", "?")
|
||||
del ckpt
|
||||
torch.cuda.empty_cache()
|
||||
return model, config, step, loss
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_stream(model, tokenizer, prompt, max_new_tokens=512,
|
||||
temperature=0.8, top_k=50, top_p=0.9,
|
||||
repetition_penalty=1.15, device="cuda:0",
|
||||
stop_token_ids=None):
|
||||
"""Generate tokens one at a time, yielding each for streaming output."""
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
||||
generated_ids = []
|
||||
prev_decoded_len = 0
|
||||
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = set()
|
||||
else:
|
||||
stop_token_ids = set(stop_token_ids)
|
||||
stop_token_ids.add(tokenizer.eos_token_id)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
if input_ids.shape[1] >= model.config.max_seq_len:
|
||||
break
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
logits, _ = model(input_ids)
|
||||
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if repetition_penalty != 1.0 and generated_ids:
|
||||
prev_tokens = torch.tensor(generated_ids, device=device).unique()
|
||||
for token_id in prev_tokens:
|
||||
if logits[0, token_id] > 0:
|
||||
logits[0, token_id] /= repetition_penalty
|
||||
else:
|
||||
logits[0, token_id] *= repetition_penalty
|
||||
|
||||
logits = logits / temperature
|
||||
|
||||
if top_k > 0:
|
||||
topk_vals, _ = torch.topk(logits, top_k)
|
||||
logits[logits < topk_vals[:, -1:]] = float("-inf")
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
mask = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p
|
||||
sorted_logits[mask] = float("-inf")
|
||||
logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
|
||||
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
token_id = next_token.item()
|
||||
|
||||
# Stop on any stop token (EOS, <|end|>, <|user|>)
|
||||
if token_id in stop_token_ids:
|
||||
break
|
||||
|
||||
generated_ids.append(token_id)
|
||||
input_ids = torch.cat([input_ids, next_token], dim=1)
|
||||
|
||||
full_decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
|
||||
new_text = full_decoded[prev_decoded_len:]
|
||||
prev_decoded_len = len(full_decoded)
|
||||
yield new_text
|
||||
|
||||
return
|
||||
|
||||
|
||||
def print_banner(step, loss, device):
|
||||
print("\033[1;36m") # cyan bold
|
||||
print("=" * 60)
|
||||
print(" 1B TRANSFORMER — Interactive Chat")
|
||||
print("=" * 60)
|
||||
print(f"\033[0m Checkpoint : step {step}")
|
||||
print(f" Loss : {loss}")
|
||||
print(f" Device : {device}")
|
||||
print(f" Parameters : 1.106B")
|
||||
print()
|
||||
print(" \033[90mCommands:\033[0m")
|
||||
print(" \033[33m/quit\033[0m — exit")
|
||||
print(" \033[33m/clear\033[0m — clear conversation context")
|
||||
print(" \033[33m/temp N\033[0m — set temperature (default 0.8)")
|
||||
print(" \033[33m/tokens N\033[0m — set max tokens (default 512)")
|
||||
print(" \033[33m/topp N\033[0m — set top-p (default 0.9)")
|
||||
print(" \033[33m/topk N\033[0m — set top-k (default 50)")
|
||||
print(" \033[33m/rep N\033[0m — set repetition penalty (default 1.15)")
|
||||
print()
|
||||
print("\033[90m" + "─" * 60 + "\033[0m")
|
||||
|
||||
|
||||
def main():
|
||||
device = "cuda:0"
|
||||
|
||||
is_sft = False
|
||||
if len(sys.argv) > 1:
|
||||
checkpoint = sys.argv[1]
|
||||
is_sft = "sft" in checkpoint.lower()
|
||||
else:
|
||||
result = find_latest_checkpoint()
|
||||
if result[0] is None:
|
||||
print("No checkpoint found!")
|
||||
sys.exit(1)
|
||||
checkpoint, is_sft = result
|
||||
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
# Add chat tokens for SFT models
|
||||
if is_sft:
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
print(f"\n Loading model from {checkpoint}...")
|
||||
print(f" Mode: {'SFT (chat)' if is_sft else 'Base (completion)'}")
|
||||
model, config, step, loss = load_model(checkpoint, tokenizer, device)
|
||||
print(f" Model loaded!\n")
|
||||
|
||||
print_banner(step, loss, device)
|
||||
if is_sft:
|
||||
print(" \033[1;32mSFT mode: The model will respond as a chat assistant.\033[0m\n")
|
||||
|
||||
# Settings
|
||||
temperature = 0.7 if is_sft else 0.8
|
||||
max_tokens = 512
|
||||
top_p = 0.9
|
||||
top_k = 50
|
||||
rep_penalty = 1.15
|
||||
context = ""
|
||||
|
||||
# Chat template tokens for SFT
|
||||
USER_START = "<|user|>\n"
|
||||
ASST_START = "<|assistant|>\n"
|
||||
TURN_END = "\n<|end|>\n"
|
||||
|
||||
# Build stop token IDs for generation
|
||||
sft_stop_ids = []
|
||||
if is_sft:
|
||||
vocab = tokenizer.get_vocab()
|
||||
for tok_str in ["<|end|>", "<|user|>"]:
|
||||
if tok_str in vocab:
|
||||
sft_stop_ids.append(vocab[tok_str])
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("\n\033[1;32mYou:\033[0m ").strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\n\nGoodbye!")
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Handle commands
|
||||
if user_input.startswith("/"):
|
||||
cmd = user_input.lower().split()
|
||||
if cmd[0] == "/quit":
|
||||
print("Goodbye!")
|
||||
break
|
||||
elif cmd[0] == "/clear":
|
||||
context = ""
|
||||
print("\033[90m [Context cleared]\033[0m")
|
||||
continue
|
||||
elif cmd[0] == "/temp" and len(cmd) > 1:
|
||||
temperature = float(cmd[1])
|
||||
print(f"\033[90m [Temperature set to {temperature}]\033[0m")
|
||||
continue
|
||||
elif cmd[0] == "/tokens" and len(cmd) > 1:
|
||||
max_tokens = int(cmd[1])
|
||||
print(f"\033[90m [Max tokens set to {max_tokens}]\033[0m")
|
||||
continue
|
||||
elif cmd[0] == "/topp" and len(cmd) > 1:
|
||||
top_p = float(cmd[1])
|
||||
print(f"\033[90m [Top-p set to {top_p}]\033[0m")
|
||||
continue
|
||||
elif cmd[0] == "/topk" and len(cmd) > 1:
|
||||
top_k = int(cmd[1])
|
||||
print(f"\033[90m [Top-k set to {top_k}]\033[0m")
|
||||
continue
|
||||
elif cmd[0] == "/rep" and len(cmd) > 1:
|
||||
rep_penalty = float(cmd[1])
|
||||
print(f"\033[90m [Repetition penalty set to {rep_penalty}]\033[0m")
|
||||
continue
|
||||
else:
|
||||
print("\033[90m Unknown command. Try /quit, /clear, /temp, /tokens, /topp, /topk, /rep\033[0m")
|
||||
continue
|
||||
|
||||
# Build prompt
|
||||
if is_sft:
|
||||
prompt = context + USER_START + user_input + TURN_END + ASST_START
|
||||
else:
|
||||
if context:
|
||||
prompt = context + "\n" + user_input
|
||||
else:
|
||||
prompt = user_input
|
||||
|
||||
# Trim context if too long
|
||||
while len(tokenizer.encode(prompt)) > config.max_seq_len - max_tokens:
|
||||
if is_sft:
|
||||
parts = context.split(TURN_END)
|
||||
if len(parts) <= 2:
|
||||
break
|
||||
context = TURN_END.join(parts[2:])
|
||||
prompt = context + USER_START + user_input + TURN_END + ASST_START
|
||||
else:
|
||||
lines = prompt.split("\n")
|
||||
if len(lines) <= 2:
|
||||
break
|
||||
prompt = "\n".join(lines[1:])
|
||||
|
||||
# Generate with streaming
|
||||
print("\033[1;34mModel:\033[0m ", end="", flush=True)
|
||||
t0 = time.time()
|
||||
full_response = ""
|
||||
token_count = 0
|
||||
|
||||
for token_text in generate_stream(
|
||||
model, tokenizer, prompt,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
repetition_penalty=rep_penalty,
|
||||
device=device,
|
||||
stop_token_ids=sft_stop_ids if is_sft else None,
|
||||
):
|
||||
print(token_text, end="", flush=True)
|
||||
full_response += token_text
|
||||
token_count += 1
|
||||
|
||||
elapsed = time.time() - t0
|
||||
tps = token_count / max(elapsed, 1e-9)
|
||||
print(f"\n\033[90m [{token_count} tokens, {tps:.1f} tok/s, {elapsed:.1f}s]\033[0m")
|
||||
|
||||
# Append to context for multi-turn
|
||||
if is_sft:
|
||||
context = (context + USER_START + user_input + TURN_END +
|
||||
ASST_START + full_response.strip() + TURN_END)
|
||||
else:
|
||||
context = prompt + full_response
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
167
training_code/export_to_hf.py
Normal file
167
training_code/export_to_hf.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Export the trained model to HuggingFace-compatible format.
|
||||
|
||||
Creates:
|
||||
- model.safetensors (weights)
|
||||
- config.json (architecture config)
|
||||
- generation_config.json
|
||||
- tokenizer.json, tokenizer_config.json, special_tokens_map.json
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from safetensors.torch import save_file
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer
|
||||
|
||||
CHECKPOINT = "/jfs/deepak-kumar/checkpoints_dpo/dpo_final.pt"
|
||||
OUTPUT_DIR = "/home/jovyan/training/hf_model"
|
||||
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
print("=" * 60)
|
||||
print(" EXPORTING MODEL TO HUGGING FACE FORMAT")
|
||||
print("=" * 60)
|
||||
|
||||
# --- 1. Load model ---
|
||||
print("\n[1/4] Loading checkpoint...")
|
||||
tokenizer = get_tokenizer()
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
model_config = ModelConfig()
|
||||
model_config.vocab_size = len(tokenizer)
|
||||
|
||||
model = Transformer(model_config)
|
||||
ckpt = torch.load(CHECKPOINT, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(ckpt["model"])
|
||||
step = ckpt.get("step", 0)
|
||||
del ckpt
|
||||
print(f" Loaded DPO model (step {step}, vocab {model_config.vocab_size})")
|
||||
|
||||
# --- 2. Convert state dict keys to HF-style naming ---
|
||||
print("\n[2/4] Converting weights to safetensors...")
|
||||
|
||||
state_dict = model.state_dict()
|
||||
hf_state = OrderedDict()
|
||||
|
||||
KEY_MAP = {
|
||||
"tok_embeddings.weight": "model.embed_tokens.weight",
|
||||
"norm.weight": "model.norm.weight",
|
||||
"output.weight": "lm_head.weight",
|
||||
}
|
||||
|
||||
for key, tensor in state_dict.items():
|
||||
if key in KEY_MAP:
|
||||
hf_state[KEY_MAP[key]] = tensor
|
||||
continue
|
||||
|
||||
if key.startswith("layers."):
|
||||
parts = key.split(".")
|
||||
layer_idx = parts[1]
|
||||
rest = ".".join(parts[2:])
|
||||
|
||||
layer_map = {
|
||||
"attention_norm.weight": f"model.layers.{layer_idx}.input_layernorm.weight",
|
||||
"ffn_norm.weight": f"model.layers.{layer_idx}.post_attention_layernorm.weight",
|
||||
"attention.wq.weight": f"model.layers.{layer_idx}.self_attn.q_proj.weight",
|
||||
"attention.wk.weight": f"model.layers.{layer_idx}.self_attn.k_proj.weight",
|
||||
"attention.wv.weight": f"model.layers.{layer_idx}.self_attn.v_proj.weight",
|
||||
"attention.wo.weight": f"model.layers.{layer_idx}.self_attn.o_proj.weight",
|
||||
"ffn.w_gate.weight": f"model.layers.{layer_idx}.mlp.gate_proj.weight",
|
||||
"ffn.w_up.weight": f"model.layers.{layer_idx}.mlp.up_proj.weight",
|
||||
"ffn.w_down.weight": f"model.layers.{layer_idx}.mlp.down_proj.weight",
|
||||
}
|
||||
|
||||
if rest in layer_map:
|
||||
hf_state[layer_map[rest]] = tensor
|
||||
else:
|
||||
print(f" WARNING: unmapped key {key}")
|
||||
hf_state[key] = tensor
|
||||
elif key == "freqs_cis":
|
||||
continue
|
||||
else:
|
||||
print(f" WARNING: unmapped key {key}")
|
||||
hf_state[key] = tensor
|
||||
|
||||
# Convert all to bfloat16 for storage
|
||||
for k in hf_state:
|
||||
if hf_state[k].dtype == torch.float32:
|
||||
hf_state[k] = hf_state[k].to(torch.bfloat16)
|
||||
|
||||
safetensors_path = os.path.join(OUTPUT_DIR, "model.safetensors")
|
||||
save_file(hf_state, safetensors_path)
|
||||
size_gb = os.path.getsize(safetensors_path) / 1e9
|
||||
print(f" Saved {len(hf_state)} tensors -> {safetensors_path} ({size_gb:.2f} GB)")
|
||||
|
||||
# --- 3. Write config files ---
|
||||
print("\n[3/4] Writing config files...")
|
||||
|
||||
config_json = {
|
||||
"architectures": ["LlamaForCausalLM"],
|
||||
"model_type": "llama",
|
||||
"vocab_size": model_config.vocab_size,
|
||||
"hidden_size": model_config.hidden_dim,
|
||||
"intermediate_size": model_config.intermediate_dim,
|
||||
"num_hidden_layers": model_config.num_layers,
|
||||
"num_attention_heads": model_config.num_attention_heads,
|
||||
"num_key_value_heads": model_config.num_kv_heads,
|
||||
"max_position_embeddings": model_config.max_seq_len,
|
||||
"rope_theta": model_config.rope_theta,
|
||||
"rms_norm_eps": model_config.rms_norm_eps,
|
||||
"hidden_act": "silu",
|
||||
"initializer_range": 0.02,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.40.0",
|
||||
"use_cache": True,
|
||||
"bos_token_id": tokenizer.bos_token_id,
|
||||
"eos_token_id": tokenizer.eos_token_id,
|
||||
"pad_token_id": tokenizer.pad_token_id,
|
||||
}
|
||||
|
||||
with open(os.path.join(OUTPUT_DIR, "config.json"), "w") as f:
|
||||
json.dump(config_json, f, indent=2)
|
||||
print(" config.json")
|
||||
|
||||
gen_config = {
|
||||
"bos_token_id": tokenizer.bos_token_id,
|
||||
"eos_token_id": tokenizer.eos_token_id,
|
||||
"pad_token_id": tokenizer.pad_token_id,
|
||||
"do_sample": True,
|
||||
"temperature": 0.7,
|
||||
"top_k": 50,
|
||||
"top_p": 0.9,
|
||||
"repetition_penalty": 1.15,
|
||||
"max_new_tokens": 512,
|
||||
"transformers_version": "4.40.0",
|
||||
}
|
||||
|
||||
with open(os.path.join(OUTPUT_DIR, "generation_config.json"), "w") as f:
|
||||
json.dump(gen_config, f, indent=2)
|
||||
print(" generation_config.json")
|
||||
|
||||
# --- 4. Export tokenizer ---
|
||||
print("\n[4/4] Exporting tokenizer...")
|
||||
tokenizer.save_pretrained(OUTPUT_DIR)
|
||||
print(" Tokenizer files saved")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(" EXPORT COMPLETE -> " + OUTPUT_DIR)
|
||||
print("=" * 60)
|
||||
print("\nFiles:")
|
||||
for f in sorted(os.listdir(OUTPUT_DIR)):
|
||||
size = os.path.getsize(os.path.join(OUTPUT_DIR, f))
|
||||
if size > 1e6:
|
||||
print(f" {f:40s} {size/1e6:.1f} MB")
|
||||
else:
|
||||
print(f" {f:40s} {size/1e3:.1f} KB")
|
||||
134
training_code/inference.py
Normal file
134
training_code/inference.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
Inference script for the 1B Transformer — Single GPU.
|
||||
|
||||
Usage:
|
||||
python inference.py # auto-finds latest checkpoint
|
||||
python inference.py /path/to/checkpoint.pt # specific checkpoint
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import glob
|
||||
import time
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer
|
||||
|
||||
|
||||
def find_latest_checkpoint(checkpoint_dir="/jfs/deepak-kumar/checkpoints"):
|
||||
files = glob.glob(os.path.join(checkpoint_dir, "step_*.pt"))
|
||||
if not files:
|
||||
final = os.path.join(checkpoint_dir, "final.pt")
|
||||
return final if os.path.exists(final) else None
|
||||
return max(files, key=lambda f: int(os.path.basename(f).split("_")[1].split(".")[0]))
|
||||
|
||||
|
||||
def load_model(checkpoint_path, device="cuda:0"):
|
||||
config = ModelConfig()
|
||||
model = Transformer(config)
|
||||
|
||||
print(f"Loading checkpoint: {checkpoint_path}")
|
||||
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(ckpt["model"])
|
||||
model = model.to(device).bfloat16().eval()
|
||||
|
||||
step = ckpt.get("step", "?")
|
||||
loss = ckpt.get("loss", "?")
|
||||
print(f" Step: {step} | Loss: {loss}")
|
||||
print(f" Params: {sum(p.numel() for p in model.parameters()):,}")
|
||||
print(f" Device: {device}")
|
||||
del ckpt
|
||||
torch.cuda.empty_cache()
|
||||
return model, config
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(model, tokenizer, prompt, max_new_tokens=200,
|
||||
temperature=0.8, top_k=50, top_p=0.9, device="cuda:0"):
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
||||
t0 = time.time()
|
||||
|
||||
for i in range(max_new_tokens):
|
||||
if input_ids.shape[1] >= model.config.max_seq_len:
|
||||
break
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
logits, _ = model(input_ids)
|
||||
|
||||
logits = logits[:, -1, :] / temperature
|
||||
|
||||
if top_k > 0:
|
||||
topk_vals, _ = torch.topk(logits, top_k)
|
||||
logits[logits < topk_vals[:, -1:]] = float("-inf")
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
mask = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p
|
||||
sorted_logits[mask] = float("-inf")
|
||||
logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
|
||||
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
if next_token.item() == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
input_ids = torch.cat([input_ids, next_token], dim=1)
|
||||
|
||||
elapsed = time.time() - t0
|
||||
gen_tokens = input_ids.shape[1] - len(tokenizer.encode(prompt))
|
||||
tok_per_sec = gen_tokens / max(elapsed, 1e-9)
|
||||
|
||||
text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
|
||||
return text, gen_tokens, tok_per_sec
|
||||
|
||||
|
||||
def main():
|
||||
device = "cuda:0"
|
||||
if len(sys.argv) > 1:
|
||||
checkpoint = sys.argv[1]
|
||||
else:
|
||||
checkpoint = find_latest_checkpoint()
|
||||
if checkpoint is None:
|
||||
print("No checkpoint found!")
|
||||
sys.exit(1)
|
||||
|
||||
model, config = load_model(checkpoint, device)
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
prompts = [
|
||||
"The meaning of life is",
|
||||
"In machine learning, a neural network",
|
||||
"The capital of France is",
|
||||
"Once upon a time, there was a",
|
||||
"To solve a quadratic equation, you need to",
|
||||
"The theory of relativity explains that",
|
||||
"Python is a programming language that",
|
||||
"The sun rises in the east and",
|
||||
]
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print(" INFERENCE — 1B Transformer (Single GPU)")
|
||||
print("=" * 70)
|
||||
|
||||
for prompt in prompts:
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"PROMPT: {prompt}")
|
||||
print(f"{'─' * 60}")
|
||||
text, n_tok, tps = generate(model, tokenizer, prompt,
|
||||
max_new_tokens=150, temperature=0.8,
|
||||
top_k=50, device=device)
|
||||
generated = text[len(prompt):]
|
||||
print(f"OUTPUT:{generated}")
|
||||
print(f" [{n_tok} tokens, {tps:.1f} tok/s]")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
2
training_code/model/__init__.py
Normal file
2
training_code/model/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .config import ModelConfig, TrainConfig
|
||||
from .transformer import Transformer
|
||||
78
training_code/model/config.py
Normal file
78
training_code/model/config.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Configuration for 1B parameter LLaMA-style Transformer model.
|
||||
Architecture: Decoder-only Transformer with RoPE, GQA, SwiGLU, RMSNorm.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
vocab_size: int = 32000
|
||||
hidden_dim: int = 2048
|
||||
intermediate_dim: int = 5504 # ~2.7x hidden for SwiGLU (adjusted for param count)
|
||||
num_layers: int = 22
|
||||
num_attention_heads: int = 32
|
||||
num_kv_heads: int = 8 # GQA: 4 query heads per KV head
|
||||
max_seq_len: int = 2048
|
||||
rope_theta: float = 10000.0
|
||||
rms_norm_eps: float = 1e-5
|
||||
dropout: float = 0.0 # No dropout (modern practice for pretraining)
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
@property
|
||||
def head_dim(self) -> int:
|
||||
return self.hidden_dim // self.num_attention_heads
|
||||
|
||||
@property
|
||||
def num_params_approx(self) -> int:
|
||||
"""Rough parameter count estimate."""
|
||||
embed = self.vocab_size * self.hidden_dim
|
||||
attn_per_layer = (
|
||||
self.hidden_dim * self.head_dim * self.num_attention_heads + # Q
|
||||
self.hidden_dim * self.head_dim * self.num_kv_heads + # K
|
||||
self.hidden_dim * self.head_dim * self.num_kv_heads + # V
|
||||
self.head_dim * self.num_attention_heads * self.hidden_dim # O
|
||||
)
|
||||
ffn_per_layer = 3 * self.hidden_dim * self.intermediate_dim # gate + up + down
|
||||
norm_per_layer = 2 * self.hidden_dim
|
||||
total = (
|
||||
embed +
|
||||
self.num_layers * (attn_per_layer + ffn_per_layer + norm_per_layer) +
|
||||
self.hidden_dim + # final norm
|
||||
(0 if self.tie_word_embeddings else self.vocab_size * self.hidden_dim)
|
||||
)
|
||||
return total
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
# Paths
|
||||
checkpoint_dir: str = "/jfs/deepak-kumar/checkpoints"
|
||||
data_cache_dir: str = "/jfs/deepak-kumar/data"
|
||||
log_dir: str = "/home/jovyan/training/logs"
|
||||
|
||||
# Training
|
||||
total_tokens: int = 20_000_000_000 # 20B tokens
|
||||
batch_size_per_gpu: int = 8
|
||||
gradient_accumulation_steps: int = 8 # effective batch = 8 * 8 * 8 = 512 seqs
|
||||
max_seq_len: int = 2048
|
||||
|
||||
# WSD Schedule
|
||||
learning_rate: float = 3e-4
|
||||
min_lr: float = 3e-5
|
||||
warmup_steps: int = 1000
|
||||
weight_decay: float = 0.1
|
||||
beta1: float = 0.9
|
||||
beta2: float = 0.95
|
||||
grad_clip: float = 1.0
|
||||
|
||||
# Logging
|
||||
log_interval: int = 10
|
||||
save_interval: int = 1000
|
||||
eval_interval: int = 500
|
||||
|
||||
# System
|
||||
num_workers: int = 4
|
||||
seed: int = 42
|
||||
bf16: bool = True
|
||||
79
training_code/model/data.py
Normal file
79
training_code/model/data.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
Data pipeline: streams and tokenizes OpenWebText for pretraining.
|
||||
Packs sequences to max_seq_len for efficiency (no padding waste).
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data import IterableDataset, DataLoader
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def get_tokenizer(name: str = "mistralai/Mistral-7B-v0.1"):
|
||||
"""Use Mistral's tokenizer — 32k vocab, BPE, well-trained on diverse data."""
|
||||
tok = AutoTokenizer.from_pretrained(name, use_fast=True)
|
||||
if tok.pad_token is None:
|
||||
tok.pad_token = tok.eos_token
|
||||
return tok
|
||||
|
||||
|
||||
class PackedPretrainDataset(IterableDataset):
|
||||
"""
|
||||
Streams text from HuggingFace dataset, tokenizes on the fly,
|
||||
and packs into fixed-length sequences for maximum GPU utilization.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, max_seq_len: int, split: str = "train", cache_dir: str = None, seed: int = 42):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_seq_len = max_seq_len
|
||||
self.split = split
|
||||
self.cache_dir = cache_dir
|
||||
self.seed = seed
|
||||
self.eos_id = tokenizer.eos_token_id
|
||||
|
||||
def _token_stream(self):
|
||||
ds = load_dataset(
|
||||
"HuggingFaceFW/fineweb-edu",
|
||||
name="sample-10BT",
|
||||
split=self.split,
|
||||
streaming=True,
|
||||
cache_dir=self.cache_dir,
|
||||
)
|
||||
ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
|
||||
|
||||
for example in ds:
|
||||
text = example.get("text", "")
|
||||
if len(text.strip()) < 50:
|
||||
continue
|
||||
token_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
yield from token_ids
|
||||
yield self.eos_id
|
||||
|
||||
def __iter__(self):
|
||||
buffer = []
|
||||
for token_id in self._token_stream():
|
||||
buffer.append(token_id)
|
||||
if len(buffer) == self.max_seq_len + 1:
|
||||
input_ids = torch.tensor(buffer[:-1], dtype=torch.long)
|
||||
labels = torch.tensor(buffer[1:], dtype=torch.long)
|
||||
yield input_ids, labels
|
||||
buffer = []
|
||||
|
||||
|
||||
def create_dataloader(tokenizer, config, rank: int = 0, world_size: int = 1, seed_override: int = None):
|
||||
seed = seed_override if seed_override is not None else config.seed
|
||||
dataset = PackedPretrainDataset(
|
||||
tokenizer=tokenizer,
|
||||
max_seq_len=config.max_seq_len,
|
||||
split="train",
|
||||
cache_dir=config.data_cache_dir,
|
||||
seed=seed + rank,
|
||||
)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=config.batch_size_per_gpu,
|
||||
num_workers=config.num_workers,
|
||||
pin_memory=True,
|
||||
prefetch_factor=4,
|
||||
)
|
||||
144
training_code/model/dpo_data.py
Normal file
144
training_code/model/dpo_data.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
DPO data pipeline: loads UltraFeedback preference pairs.
|
||||
|
||||
Each example has a prompt + chosen response + rejected response.
|
||||
We tokenize both (prompt+chosen) and (prompt+rejected), apply the same
|
||||
chat template, and return them as pairs for DPO training.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
CHAT_TEMPLATE = {
|
||||
"user_start": "<|user|>\n",
|
||||
"assistant_start": "<|assistant|>\n",
|
||||
"turn_end": "\n<|end|>\n",
|
||||
}
|
||||
|
||||
|
||||
def format_preference_pair(prompt, chosen_msgs, rejected_msgs):
|
||||
"""Build chat-templated strings for chosen and rejected."""
|
||||
def build(messages):
|
||||
text = CHAT_TEMPLATE["user_start"] + prompt.strip() + CHAT_TEMPLATE["turn_end"]
|
||||
for msg in messages:
|
||||
role = msg.get("role", "assistant")
|
||||
content = msg.get("content", "").strip()
|
||||
if role == "assistant":
|
||||
text += CHAT_TEMPLATE["assistant_start"] + content + CHAT_TEMPLATE["turn_end"]
|
||||
elif role == "user":
|
||||
text += CHAT_TEMPLATE["user_start"] + content + CHAT_TEMPLATE["turn_end"]
|
||||
return text
|
||||
|
||||
return build(chosen_msgs), build(rejected_msgs)
|
||||
|
||||
|
||||
class DPODataset(Dataset):
|
||||
"""
|
||||
Loads UltraFeedback preference pairs and tokenizes them.
|
||||
Returns (prompt_ids, chosen_ids, rejected_ids) with proper shifting.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, max_seq_len=2048, split="train",
|
||||
cache_dir=None, max_samples=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
self.assistant_token_id = tokenizer.encode("<|assistant|>", add_special_tokens=False)[0]
|
||||
self.end_token_id = tokenizer.encode("<|end|>", add_special_tokens=False)[0]
|
||||
self.user_token_id = tokenizer.encode("<|user|>", add_special_tokens=False)[0]
|
||||
|
||||
print(f"[DPO Data] Loading UltraFeedback preferences ({split})...")
|
||||
ds = load_dataset(
|
||||
"argilla/ultrafeedback-binarized-preferences-cleaned",
|
||||
split=split,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
if max_samples:
|
||||
ds = ds.select(range(min(max_samples, len(ds))))
|
||||
print(f"[DPO Data] {len(ds)} preference pairs loaded")
|
||||
|
||||
self.examples = []
|
||||
skipped = 0
|
||||
for i, row in enumerate(ds):
|
||||
prompt = row.get("prompt", "")
|
||||
chosen = row.get("chosen", [])
|
||||
rejected = row.get("rejected", [])
|
||||
|
||||
if not prompt or not chosen or not rejected:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
chosen_text, rejected_text = format_preference_pair(prompt, chosen, rejected)
|
||||
|
||||
chosen_ids = tokenizer.encode(chosen_text, add_special_tokens=False)
|
||||
rejected_ids = tokenizer.encode(rejected_text, add_special_tokens=False)
|
||||
|
||||
# Truncate if needed
|
||||
if len(chosen_ids) > max_seq_len + 1:
|
||||
chosen_ids = chosen_ids[:max_seq_len + 1]
|
||||
if len(rejected_ids) > max_seq_len + 1:
|
||||
rejected_ids = rejected_ids[:max_seq_len + 1]
|
||||
|
||||
if len(chosen_ids) < 10 or len(rejected_ids) < 10:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Find where the prompt ends (first <|assistant|> token)
|
||||
prompt_end = 0
|
||||
for j, tid in enumerate(chosen_ids):
|
||||
if tid == self.assistant_token_id:
|
||||
prompt_end = j + 2 # skip <|assistant|> and \n
|
||||
break
|
||||
|
||||
self.examples.append({
|
||||
"chosen_ids": chosen_ids,
|
||||
"rejected_ids": rejected_ids,
|
||||
"prompt_len": prompt_end,
|
||||
})
|
||||
|
||||
if (i + 1) % 20000 == 0:
|
||||
print(f" Processed {i+1} pairs...")
|
||||
|
||||
print(f"[DPO Data] {len(self.examples)} pairs ready, {skipped} skipped")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
ex = self.examples[idx]
|
||||
return {
|
||||
"chosen_ids": torch.tensor(ex["chosen_ids"], dtype=torch.long),
|
||||
"rejected_ids": torch.tensor(ex["rejected_ids"], dtype=torch.long),
|
||||
"prompt_len": ex["prompt_len"],
|
||||
}
|
||||
|
||||
|
||||
def dpo_collate_fn(batch, pad_id=0):
|
||||
"""Pad chosen and rejected sequences separately."""
|
||||
max_chosen = max(b["chosen_ids"].size(0) for b in batch)
|
||||
max_rejected = max(b["rejected_ids"].size(0) for b in batch)
|
||||
|
||||
chosen_padded = []
|
||||
rejected_padded = []
|
||||
prompt_lens = []
|
||||
|
||||
for b in batch:
|
||||
c_pad = max_chosen - b["chosen_ids"].size(0)
|
||||
r_pad = max_rejected - b["rejected_ids"].size(0)
|
||||
chosen_padded.append(torch.cat([b["chosen_ids"], torch.full((c_pad,), pad_id, dtype=torch.long)]))
|
||||
rejected_padded.append(torch.cat([b["rejected_ids"], torch.full((r_pad,), pad_id, dtype=torch.long)]))
|
||||
prompt_lens.append(b["prompt_len"])
|
||||
|
||||
return {
|
||||
"chosen_ids": torch.stack(chosen_padded),
|
||||
"rejected_ids": torch.stack(rejected_padded),
|
||||
"prompt_lens": torch.tensor(prompt_lens, dtype=torch.long),
|
||||
}
|
||||
169
training_code/model/sft_data.py
Normal file
169
training_code/model/sft_data.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
SFT data pipeline: loads UltraChat 200K and formats into chat template.
|
||||
|
||||
Chat template:
|
||||
<|user|>
|
||||
What is gravity?
|
||||
<|end|>
|
||||
<|assistant|>
|
||||
Gravity is a fundamental force...
|
||||
<|end|>
|
||||
|
||||
Labels are shifted left by 1 (standard causal LM), with user turns masked.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
CHAT_TEMPLATE = {
|
||||
"user_start": "<|user|>\n",
|
||||
"assistant_start": "<|assistant|>\n",
|
||||
"turn_end": "\n<|end|>\n",
|
||||
}
|
||||
|
||||
|
||||
def format_conversation(messages):
|
||||
"""Convert a list of {role, content} messages into our chat template string."""
|
||||
text = ""
|
||||
for msg in messages:
|
||||
role = msg["role"]
|
||||
content = msg["content"].strip()
|
||||
if role == "user":
|
||||
text += CHAT_TEMPLATE["user_start"] + content + CHAT_TEMPLATE["turn_end"]
|
||||
elif role == "assistant":
|
||||
text += CHAT_TEMPLATE["assistant_start"] + content + CHAT_TEMPLATE["turn_end"]
|
||||
return text
|
||||
|
||||
|
||||
class SFTDataset(Dataset):
|
||||
"""
|
||||
Loads UltraChat 200K conversations, tokenizes them, builds shifted labels
|
||||
with user turns masked so the model only learns to generate assistant responses.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, max_seq_len=2048, split="train_sft", cache_dir=None, max_samples=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
self.assistant_token_id = tokenizer.encode("<|assistant|>", add_special_tokens=False)[0]
|
||||
self.end_token_id = tokenizer.encode("<|end|>", add_special_tokens=False)[0]
|
||||
self.user_token_id = tokenizer.encode("<|user|>", add_special_tokens=False)[0]
|
||||
|
||||
print(f"[SFT Data] Loading UltraChat 200K ({split})...")
|
||||
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=split, cache_dir=cache_dir)
|
||||
if max_samples:
|
||||
ds = ds.select(range(min(max_samples, len(ds))))
|
||||
print(f"[SFT Data] {len(ds)} conversations loaded")
|
||||
|
||||
self.examples = []
|
||||
skipped = 0
|
||||
for i, row in enumerate(ds):
|
||||
messages = row["messages"]
|
||||
if len(messages) < 2:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
text = format_conversation(messages)
|
||||
all_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
# Need at least max_seq_len+1 for shift, but truncate if longer
|
||||
if len(all_ids) > max_seq_len + 1:
|
||||
all_ids = all_ids[:max_seq_len + 1]
|
||||
|
||||
if len(all_ids) < 10:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Shifted: input = all_ids[:-1], target = all_ids[1:]
|
||||
input_ids = all_ids[:-1]
|
||||
target_ids = all_ids[1:]
|
||||
|
||||
# Build mask: -100 for user turns, real token id for assistant turns
|
||||
labels = self._build_shifted_labels(input_ids, target_ids)
|
||||
self.examples.append((input_ids, labels))
|
||||
|
||||
if (i + 1) % 50000 == 0:
|
||||
print(f" Processed {i+1} conversations...")
|
||||
|
||||
print(f"[SFT Data] {len(self.examples)} examples ready, {skipped} skipped")
|
||||
|
||||
def _build_shifted_labels(self, input_ids, target_ids):
|
||||
"""
|
||||
Walk through the token sequence and track whether we're in a user turn
|
||||
or assistant turn. Only keep labels for assistant response content.
|
||||
|
||||
Masking strategy (applied to the SHIFTED target):
|
||||
- Everything before and including <|assistant|>\\n: masked
|
||||
- Assistant response content and <|end|>: TRAIN
|
||||
- <|user|> and user content until next <|assistant|>: masked
|
||||
"""
|
||||
labels = [-100] * len(target_ids)
|
||||
in_assistant = False
|
||||
|
||||
for i, tid in enumerate(input_ids):
|
||||
if tid == self.assistant_token_id:
|
||||
# Next token after <|assistant|> is \n, then content starts
|
||||
in_assistant = True
|
||||
continue
|
||||
|
||||
if tid == self.user_token_id:
|
||||
in_assistant = False
|
||||
continue
|
||||
|
||||
if in_assistant:
|
||||
labels[i] = target_ids[i]
|
||||
|
||||
# When we hit <|end|> in assistant mode, include it then switch off
|
||||
if tid == self.end_token_id and in_assistant:
|
||||
in_assistant = False
|
||||
|
||||
return labels
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
input_ids, labels = self.examples[idx]
|
||||
return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
|
||||
|
||||
|
||||
def sft_collate_fn(batch, pad_id=0):
|
||||
"""Pad sequences to the same length within a batch."""
|
||||
input_ids_list, labels_list = zip(*batch)
|
||||
max_len = max(ids.size(0) for ids in input_ids_list)
|
||||
|
||||
padded_inputs = []
|
||||
padded_labels = []
|
||||
for ids, lbl in zip(input_ids_list, labels_list):
|
||||
pad_len = max_len - ids.size(0)
|
||||
padded_inputs.append(torch.cat([ids, torch.full((pad_len,), pad_id, dtype=torch.long)]))
|
||||
padded_labels.append(torch.cat([lbl, torch.full((pad_len,), -100, dtype=torch.long)]))
|
||||
|
||||
return torch.stack(padded_inputs), torch.stack(padded_labels)
|
||||
|
||||
|
||||
def create_sft_dataloader(tokenizer, batch_size=4, max_seq_len=2048,
|
||||
cache_dir=None, max_samples=None, num_workers=4):
|
||||
dataset = SFTDataset(
|
||||
tokenizer=tokenizer,
|
||||
max_seq_len=max_seq_len,
|
||||
split="train_sft",
|
||||
cache_dir=cache_dir,
|
||||
max_samples=max_samples,
|
||||
)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
collate_fn=lambda b: sft_collate_fn(b, pad_id=tokenizer.pad_token_id),
|
||||
), dataset
|
||||
163
training_code/model/transformer.py
Normal file
163
training_code/model/transformer.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
1B Parameter Decoder-Only Transformer — built from scratch.
|
||||
|
||||
Techniques:
|
||||
- RoPE (Rotary Position Embeddings)
|
||||
- Grouped Query Attention (GQA)
|
||||
- SwiGLU Feed-Forward
|
||||
- RMSNorm (pre-norm architecture)
|
||||
- Flash Attention 2 (via PyTorch SDPA)
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .config import ModelConfig
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
|
||||
return (x.float() * norm).type_as(x) * self.weight
|
||||
|
||||
|
||||
def precompute_rope_freqs(dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
|
||||
t = torch.arange(max_seq_len, dtype=torch.float32)
|
||||
freqs = torch.outer(t, freqs)
|
||||
return torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
|
||||
|
||||
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
|
||||
B, S, H, D = xq.shape
|
||||
xq_c = torch.view_as_complex(xq.float().reshape(B, S, H, D // 2, 2))
|
||||
xk_c = torch.view_as_complex(xk.float().reshape(B, S, xk.shape[2], D // 2, 2))
|
||||
freqs = freqs_cis[:S].clone().unsqueeze(0).unsqueeze(2)
|
||||
xq_out = torch.view_as_real(xq_c * freqs).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_c * freqs).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class GroupedQueryAttention(nn.Module):
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_kv_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.num_groups = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.wq = nn.Linear(config.hidden_dim, self.num_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(config.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(config.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
|
||||
self.wo = nn.Linear(self.num_heads * self.head_dim, config.hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
B, S, _ = x.shape
|
||||
|
||||
q = self.wq(x).view(B, S, self.num_heads, self.head_dim)
|
||||
k = self.wk(x).view(B, S, self.num_kv_heads, self.head_dim)
|
||||
v = self.wv(x).view(B, S, self.num_kv_heads, self.head_dim)
|
||||
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
|
||||
# Expand KV heads for GQA
|
||||
if self.num_groups > 1:
|
||||
k = k.unsqueeze(3).expand(B, S, self.num_kv_heads, self.num_groups, self.head_dim)
|
||||
k = k.reshape(B, S, self.num_heads, self.head_dim)
|
||||
v = v.unsqueeze(3).expand(B, S, self.num_kv_heads, self.num_groups, self.head_dim)
|
||||
v = v.reshape(B, S, self.num_heads, self.head_dim)
|
||||
|
||||
# (B, num_heads, S, head_dim) for SDPA
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
out = out.transpose(1, 2).contiguous().view(B, S, -1)
|
||||
return self.wo(out)
|
||||
|
||||
|
||||
class SwiGLUFFN(nn.Module):
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
self.w_gate = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
|
||||
self.w_up = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
|
||||
self.w_down = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
self.attention_norm = RMSNorm(config.hidden_dim, eps=config.rms_norm_eps)
|
||||
self.attention = GroupedQueryAttention(config)
|
||||
self.ffn_norm = RMSNorm(config.hidden_dim, eps=config.rms_norm_eps)
|
||||
self.ffn = SwiGLUFFN(config)
|
||||
|
||||
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.attention(self.attention_norm(x), freqs_cis)
|
||||
x = x + self.ffn(self.ffn_norm(x))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_dim)
|
||||
self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
|
||||
self.norm = RMSNorm(config.hidden_dim, eps=config.rms_norm_eps)
|
||||
self.output = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
|
||||
|
||||
# Pre-compute RoPE frequencies
|
||||
self.register_buffer(
|
||||
"freqs_cis",
|
||||
precompute_rope_freqs(config.head_dim, config.max_seq_len * 2, config.rope_theta),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
"""Initialize with scaled normal, following GPT-NeoX / LLaMA conventions."""
|
||||
for module in self.modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
|
||||
# Scale residual projections by 1/sqrt(2*num_layers)
|
||||
scale = (2 * self.config.num_layers) ** -0.5
|
||||
for layer in self.layers:
|
||||
nn.init.normal_(layer.attention.wo.weight, mean=0.0, std=0.02 * scale)
|
||||
nn.init.normal_(layer.ffn.w_down.weight, mean=0.0, std=0.02 * scale)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, targets: torch.Tensor = None):
|
||||
B, S = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
|
||||
freqs_cis = self.freqs_cis[:S]
|
||||
for layer in self.layers:
|
||||
h = layer(h, freqs_cis)
|
||||
h = self.norm(h)
|
||||
logits = self.output(h)
|
||||
|
||||
loss = None
|
||||
if targets is not None:
|
||||
loss = F.cross_entropy(
|
||||
logits.view(-1, logits.size(-1)),
|
||||
targets.view(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
return logits, loss
|
||||
148
training_code/test_sft.py
Normal file
148
training_code/test_sft.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Quick test of model quality with diverse prompts."""
|
||||
|
||||
import os, sys, time, torch
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer
|
||||
|
||||
DPO_CKPT = "/jfs/deepak-kumar/checkpoints_dpo/dpo_final.pt"
|
||||
SFT_CKPT = "/jfs/deepak-kumar/checkpoints_sft/sft_final.pt"
|
||||
CHECKPOINT = DPO_CKPT if os.path.exists(DPO_CKPT) else SFT_CKPT
|
||||
DEVICE = "cuda:0"
|
||||
|
||||
USER_START = "<|user|>\n"
|
||||
ASST_START = "<|assistant|>\n"
|
||||
TURN_END = "\n<|end|>\n"
|
||||
|
||||
TEST_PROMPTS = [
|
||||
"Hi! How are you?",
|
||||
"What is photosynthesis?",
|
||||
"Explain gravity to a 5-year-old.",
|
||||
"Write a short poem about the ocean.",
|
||||
"What are the three states of matter?",
|
||||
"How does a computer work?",
|
||||
"What is the capital of France and why is it famous?",
|
||||
"Give me 3 tips for learning a new language.",
|
||||
"What is machine learning in simple terms?",
|
||||
]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(model, tokenizer, prompt, max_new_tokens=256,
|
||||
temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.15):
|
||||
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
input_ids = torch.tensor([input_ids], dtype=torch.long, device=DEVICE)
|
||||
generated = []
|
||||
eos_id = tokenizer.eos_token_id
|
||||
|
||||
end_token_ids = tokenizer.encode("<|end|>", add_special_tokens=False)
|
||||
end_id = end_token_ids[0] if end_token_ids else None
|
||||
user_token_ids = tokenizer.encode("<|user|>", add_special_tokens=False)
|
||||
user_id = user_token_ids[0] if user_token_ids else None
|
||||
|
||||
stop_ids = set()
|
||||
if eos_id is not None:
|
||||
stop_ids.add(eos_id)
|
||||
if end_id is not None:
|
||||
stop_ids.add(end_id)
|
||||
if user_id is not None:
|
||||
stop_ids.add(user_id)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
logits, _ = model(input_ids)
|
||||
|
||||
logits = logits[:, -1, :].float()
|
||||
|
||||
if repetition_penalty != 1.0 and generated:
|
||||
for tid in set(generated):
|
||||
if logits[0, tid] > 0:
|
||||
logits[0, tid] /= repetition_penalty
|
||||
else:
|
||||
logits[0, tid] *= repetition_penalty
|
||||
|
||||
logits = logits / max(temperature, 1e-5)
|
||||
|
||||
if top_k > 0:
|
||||
topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
logits[logits < topk_vals[:, -1:]] = float('-inf')
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
|
||||
cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
remove = cumulative - torch.softmax(sorted_logits, dim=-1) > top_p
|
||||
sorted_logits[remove] = float('-inf')
|
||||
logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
|
||||
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, 1)
|
||||
token_id = next_token.item()
|
||||
|
||||
if token_id in stop_ids:
|
||||
break
|
||||
|
||||
generated.append(token_id)
|
||||
input_ids = torch.cat([input_ids, next_token], dim=1)
|
||||
|
||||
if input_ids.size(1) > 2048:
|
||||
break
|
||||
|
||||
return tokenizer.decode(generated, skip_special_tokens=True)
|
||||
|
||||
|
||||
def main():
|
||||
ckpt_name = "DPO" if "dpo" in CHECKPOINT else "SFT"
|
||||
print("=" * 70)
|
||||
print(" " + ckpt_name + " MODEL TEST")
|
||||
print("=" * 70)
|
||||
|
||||
tokenizer = get_tokenizer()
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
config = ModelConfig()
|
||||
config.vocab_size = len(tokenizer)
|
||||
model = Transformer(config)
|
||||
|
||||
print("")
|
||||
print("Loading checkpoint: " + CHECKPOINT)
|
||||
ckpt = torch.load(CHECKPOINT, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(ckpt["model"])
|
||||
step = ckpt.get("step", "?")
|
||||
del ckpt
|
||||
|
||||
model = model.to(DEVICE).bfloat16().eval()
|
||||
print("Model loaded (" + ckpt_name + " step " + str(step) + ", vocab " + str(config.vocab_size) + ")")
|
||||
mem = torch.cuda.max_memory_allocated(DEVICE) / 1e9
|
||||
print("GPU memory: " + str(round(mem, 1)) + " GB")
|
||||
print("-" * 70)
|
||||
|
||||
for i, question in enumerate(TEST_PROMPTS, 1):
|
||||
prompt = USER_START + question + TURN_END + ASST_START
|
||||
|
||||
print("")
|
||||
print("[Test " + str(i) + "/" + str(len(TEST_PROMPTS)) + "]")
|
||||
print(" Q: " + question)
|
||||
|
||||
t0 = time.time()
|
||||
response = generate(model, tokenizer, prompt)
|
||||
dt = time.time() - t0
|
||||
tokens = len(tokenizer.encode(response, add_special_tokens=False))
|
||||
|
||||
response = response.split("<|end|>")[0].split("<|user|>")[0].strip()
|
||||
|
||||
print(" A: " + response)
|
||||
tps = int(tokens / max(dt, 0.01))
|
||||
print(" [" + str(tokens) + " tokens, " + str(round(dt, 1)) + "s, " + str(tps) + " tok/s]")
|
||||
print("-" * 70)
|
||||
|
||||
print("")
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
257
training_code/train.py
Normal file
257
training_code/train.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""
|
||||
Distributed training script for 1B parameter Transformer.
|
||||
|
||||
Launch: torchrun --nproc_per_node=8 train.py
|
||||
|
||||
Stack: PyTorch DDP + BF16 autocast + 8x H100 80GB
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import time
|
||||
import json
|
||||
import datetime
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig, TrainConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer, create_dataloader
|
||||
|
||||
|
||||
def get_wsd_lr(step, warmup_steps, total_steps, max_lr, min_lr):
|
||||
"""Warmup-Stable-Decay: linear warmup -> constant -> cosine decay (last 20%)."""
|
||||
stable_end = int(total_steps * 0.8)
|
||||
if step < warmup_steps:
|
||||
return max_lr * step / max(warmup_steps, 1)
|
||||
elif step < stable_end:
|
||||
return max_lr
|
||||
else:
|
||||
progress = (step - stable_end) / max(total_steps - stable_end, 1)
|
||||
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
|
||||
|
||||
|
||||
def find_latest_checkpoint(checkpoint_dir):
|
||||
"""Find the latest step_*.pt checkpoint in the directory."""
|
||||
import glob
|
||||
pattern = os.path.join(checkpoint_dir, "step_*.pt")
|
||||
files = glob.glob(pattern)
|
||||
if not files:
|
||||
return None, 0
|
||||
latest = max(files, key=lambda f: int(os.path.basename(f).replace("step_", "").replace(".pt", "")))
|
||||
step = int(os.path.basename(latest).replace("step_", "").replace(".pt", ""))
|
||||
return latest, step
|
||||
|
||||
|
||||
def main():
|
||||
dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
model_config = ModelConfig()
|
||||
train_config = TrainConfig()
|
||||
|
||||
eff_batch = train_config.batch_size_per_gpu * world_size * train_config.gradient_accumulation_steps
|
||||
tokens_per_step = eff_batch * model_config.max_seq_len
|
||||
total_steps = train_config.total_tokens // tokens_per_step
|
||||
|
||||
if rank == 0:
|
||||
os.makedirs(train_config.log_dir, exist_ok=True)
|
||||
os.makedirs(train_config.checkpoint_dir, exist_ok=True)
|
||||
print("=" * 70)
|
||||
print(f" TRAINING 1B TRANSFORMER FROM SCRATCH")
|
||||
print(f" Arch: {model_config.num_layers}L / {model_config.hidden_dim}D / "
|
||||
f"{model_config.num_attention_heads}H / GQA-{model_config.num_kv_heads}KV / "
|
||||
f"SwiGLU-{model_config.intermediate_dim}")
|
||||
print(f" Seq: {model_config.max_seq_len} | Vocab: {model_config.vocab_size}")
|
||||
print(f" GPUs: {world_size}x H100 80GB | Backend: DDP + BF16 autocast")
|
||||
print(f" Batch: {eff_batch} seqs = {tokens_per_step:,} tok/step")
|
||||
print(f" Steps: {total_steps:,} | Target: {train_config.total_tokens:,} tokens")
|
||||
print("=" * 70)
|
||||
|
||||
# Tokenizer
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
# Model
|
||||
torch.manual_seed(train_config.seed)
|
||||
model = Transformer(model_config).to(device)
|
||||
|
||||
if rank == 0:
|
||||
n = sum(p.numel() for p in model.parameters())
|
||||
print(f"[Init] Params: {n:,} ({n/1e9:.3f}B)")
|
||||
|
||||
model = DDP(model, device_ids=[local_rank])
|
||||
|
||||
# Optimizer
|
||||
decay_params = [p for n, p in model.named_parameters() if p.dim() >= 2 and p.requires_grad]
|
||||
nodecay_params = [p for n, p in model.named_parameters() if p.dim() < 2 and p.requires_grad]
|
||||
optimizer = torch.optim.AdamW([
|
||||
{"params": decay_params, "weight_decay": train_config.weight_decay},
|
||||
{"params": nodecay_params, "weight_decay": 0.0},
|
||||
], lr=train_config.learning_rate, betas=(train_config.beta1, train_config.beta2), fused=True)
|
||||
|
||||
if rank == 0:
|
||||
dp = sum(p.numel() for p in decay_params)
|
||||
ndp = sum(p.numel() for p in nodecay_params)
|
||||
print(f"[Init] Optimizer: {dp:,} decay + {ndp:,} no-decay params")
|
||||
|
||||
# Resume from checkpoint
|
||||
resume_step = 0
|
||||
ckpt_path, ckpt_step = find_latest_checkpoint(train_config.checkpoint_dir)
|
||||
if ckpt_path is not None:
|
||||
if rank == 0:
|
||||
print(f"[Resume] Loading checkpoint: {ckpt_path} (step {ckpt_step})")
|
||||
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
|
||||
model.module.load_state_dict(ckpt["model"])
|
||||
optimizer.load_state_dict(ckpt["optimizer"])
|
||||
resume_step = ckpt["step"]
|
||||
if rank == 0:
|
||||
print(f"[Resume] Restored model + optimizer at step {resume_step}, "
|
||||
f"loss was {ckpt.get('loss', 'N/A')}")
|
||||
del ckpt
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
if rank == 0:
|
||||
print("[Init] No checkpoint found, starting from scratch")
|
||||
|
||||
# Data — use (seed + resume_step) so resumed runs see different shuffled data
|
||||
effective_seed = train_config.seed + resume_step
|
||||
dataloader = create_dataloader(tokenizer, train_config, rank=rank, world_size=world_size,
|
||||
seed_override=effective_seed)
|
||||
data_iter = iter(dataloader)
|
||||
|
||||
if rank == 0:
|
||||
print(f"[Init] Dataloader ready (streaming FineWeb-Edu 10BT)")
|
||||
print(f"[Schedule] WSD: warmup {train_config.warmup_steps} -> "
|
||||
f"stable {int(total_steps*0.8)} -> decay {total_steps}")
|
||||
if resume_step > 0:
|
||||
remaining = total_steps - resume_step
|
||||
print(f"[Resume] Continuing from step {resume_step}, {remaining:,} steps remaining")
|
||||
print("-" * 70)
|
||||
sys.stdout.flush()
|
||||
|
||||
# ===== TRAINING LOOP =====
|
||||
model.train()
|
||||
global_step = resume_step
|
||||
running_loss = 0.0
|
||||
best_loss = float("inf")
|
||||
tokens_done = resume_step * tokens_per_step
|
||||
t0 = time.time()
|
||||
step_t0 = time.time()
|
||||
|
||||
log_file = open(os.path.join(train_config.log_dir, "train_log.jsonl"), "a") if rank == 0 else None
|
||||
|
||||
while global_step < total_steps:
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
micro_loss = 0.0
|
||||
|
||||
for micro in range(train_config.gradient_accumulation_steps):
|
||||
try:
|
||||
input_ids, labels = next(data_iter)
|
||||
except StopIteration:
|
||||
data_iter = iter(dataloader)
|
||||
input_ids, labels = next(data_iter)
|
||||
|
||||
input_ids = input_ids.to(device, non_blocking=True)
|
||||
labels = labels.to(device, non_blocking=True)
|
||||
|
||||
# BF16 autocast — no scaler needed (BF16 has enough dynamic range)
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
_, loss = model(input_ids, labels)
|
||||
loss = loss / train_config.gradient_accumulation_steps
|
||||
|
||||
loss.backward()
|
||||
micro_loss += loss.item()
|
||||
|
||||
# Gradient clipping
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.grad_clip)
|
||||
|
||||
# LR schedule
|
||||
lr = get_wsd_lr(global_step, train_config.warmup_steps, total_steps,
|
||||
train_config.learning_rate, train_config.min_lr)
|
||||
for pg in optimizer.param_groups:
|
||||
pg["lr"] = lr
|
||||
|
||||
optimizer.step()
|
||||
global_step += 1
|
||||
running_loss += micro_loss
|
||||
tokens_done += tokens_per_step
|
||||
|
||||
# Log
|
||||
if global_step % train_config.log_interval == 0:
|
||||
dt = time.time() - step_t0
|
||||
tps = (train_config.log_interval * tokens_per_step) / max(dt, 1e-9)
|
||||
avg = running_loss / train_config.log_interval
|
||||
elapsed = time.time() - t0
|
||||
pct = 100.0 * global_step / total_steps
|
||||
eta = (elapsed / max(global_step, 1)) * (total_steps - global_step)
|
||||
|
||||
if rank == 0:
|
||||
gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
|
||||
print(
|
||||
f"[Step {global_step:>6d}/{total_steps}] "
|
||||
f"loss={avg:.4f} | lr={lr:.2e} | "
|
||||
f"tok/s={tps:,.0f} | GPU={gpu_mem:.1f}GB | "
|
||||
f"{pct:.1f}% | ETA={eta/3600:.1f}h",
|
||||
flush=True,
|
||||
)
|
||||
if log_file:
|
||||
log_file.write(json.dumps({
|
||||
"step": global_step, "loss": round(avg, 4), "lr": lr,
|
||||
"tps": round(tps), "tokens": tokens_done,
|
||||
"gpu_gb": round(gpu_mem, 1), "elapsed_s": round(elapsed, 1),
|
||||
}) + "\n")
|
||||
log_file.flush()
|
||||
|
||||
if avg < best_loss:
|
||||
best_loss = avg
|
||||
running_loss = 0.0
|
||||
step_t0 = time.time()
|
||||
|
||||
# Checkpoint
|
||||
if global_step % train_config.save_interval == 0:
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
ckpt_path = os.path.join(train_config.checkpoint_dir, f"step_{global_step}.pt")
|
||||
torch.save({
|
||||
"step": global_step,
|
||||
"model": model.module.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"loss": avg if global_step % train_config.log_interval == 0 else micro_loss,
|
||||
"config": {"model": model_config.__dict__, "train": train_config.__dict__},
|
||||
}, ckpt_path)
|
||||
print(f" >> Checkpoint: {ckpt_path}", flush=True)
|
||||
dist.barrier()
|
||||
|
||||
# Final
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
final_path = os.path.join(train_config.checkpoint_dir, "final.pt")
|
||||
torch.save({
|
||||
"step": global_step,
|
||||
"model": model.module.state_dict(),
|
||||
"config": {"model": model_config.__dict__, "train": train_config.__dict__},
|
||||
}, final_path)
|
||||
total_time = time.time() - t0
|
||||
print("=" * 70)
|
||||
print(f" TRAINING COMPLETE")
|
||||
print(f" Steps: {global_step:,} | Tokens: {tokens_done:,}")
|
||||
print(f" Time: {total_time/3600:.2f}h | Throughput: {tokens_done/total_time:,.0f} tok/s")
|
||||
print(f" Best loss: {best_loss:.4f}")
|
||||
print(f" Final model: {final_path}")
|
||||
print("=" * 70)
|
||||
if log_file:
|
||||
log_file.close()
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
327
training_code/train_dpo.py
Normal file
327
training_code/train_dpo.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""
|
||||
DPO (Direct Preference Optimization) training for the 1B Transformer.
|
||||
|
||||
Takes the SFT model and aligns it with human preferences using
|
||||
UltraFeedback preference pairs.
|
||||
|
||||
DPO Loss:
|
||||
L = -log sigma(beta * (log pi(yw|x)/pi_ref(yw|x) - log pi(yl|x)/pi_ref(yl|x)))
|
||||
|
||||
Launch: torchrun --nproc_per_node=8 train_dpo.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import time
|
||||
import json
|
||||
import datetime
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer
|
||||
from model.dpo_data import DPODataset, dpo_collate_fn
|
||||
|
||||
|
||||
# === Config ===
|
||||
SFT_CHECKPOINT = "/jfs/deepak-kumar/checkpoints_sft/sft_final.pt"
|
||||
DPO_CHECKPOINT_DIR = "/jfs/deepak-kumar/checkpoints_dpo"
|
||||
LOG_DIR = "/home/jovyan/training/logs"
|
||||
DATA_CACHE = "/jfs/deepak-kumar/data"
|
||||
|
||||
NUM_EPOCHS = 1
|
||||
BATCH_SIZE_PER_GPU = 2
|
||||
GRADIENT_ACCUMULATION = 4 # effective batch = 2 * 8 * 4 = 64
|
||||
MAX_SEQ_LEN = 1024
|
||||
LEARNING_RATE = 5e-7 # very low LR for DPO
|
||||
MIN_LR = 1e-7
|
||||
WARMUP_STEPS = 100
|
||||
WEIGHT_DECAY = 0.01
|
||||
GRAD_CLIP = 1.0
|
||||
BETA = 0.1 # DPO temperature
|
||||
LOG_INTERVAL = 10
|
||||
SAVE_INTERVAL = 200
|
||||
|
||||
|
||||
def get_cosine_lr(step, warmup_steps, total_steps, max_lr, min_lr):
|
||||
if step < warmup_steps:
|
||||
return max_lr * step / max(warmup_steps, 1)
|
||||
progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
|
||||
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
|
||||
|
||||
|
||||
def get_per_token_logps(model, input_ids, prompt_lens):
|
||||
"""
|
||||
Compute sum of log probabilities for response tokens only.
|
||||
input_ids: [B, S] full sequence (prompt + response)
|
||||
prompt_lens: [B] where response starts
|
||||
Returns: [B] sum of log probs over response tokens
|
||||
"""
|
||||
# Clone input to avoid inplace issues with shared RoPE buffers
|
||||
inp = input_ids[:, :-1].contiguous()
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
logits, _ = model(inp)
|
||||
|
||||
labels = input_ids[:, 1:].contiguous()
|
||||
log_probs = F.log_softmax(logits.float(), dim=-1)
|
||||
token_logps = log_probs.gather(2, labels.unsqueeze(2)).squeeze(2)
|
||||
|
||||
B, S = token_logps.shape
|
||||
mask = torch.zeros_like(token_logps)
|
||||
for b in range(B):
|
||||
pl = prompt_lens[b].item()
|
||||
response_start = max(0, pl - 1)
|
||||
seq_len = (labels[b] != 0).sum().item()
|
||||
mask[b, response_start:seq_len] = 1.0
|
||||
|
||||
return (token_logps * mask).sum(dim=1)
|
||||
|
||||
|
||||
def dpo_loss(policy_chosen_logps, policy_rejected_logps,
|
||||
ref_chosen_logps, ref_rejected_logps, beta=0.1):
|
||||
"""Compute DPO loss and metrics."""
|
||||
chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
|
||||
rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
|
||||
|
||||
logits = chosen_rewards - rejected_rewards
|
||||
loss = -F.logsigmoid(logits).mean()
|
||||
|
||||
with torch.no_grad():
|
||||
chosen_better = (chosen_rewards > rejected_rewards).float().mean()
|
||||
reward_margin = (chosen_rewards - rejected_rewards).mean()
|
||||
|
||||
return loss, chosen_better.item(), reward_margin.item()
|
||||
|
||||
|
||||
def main():
|
||||
dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
if rank == 0:
|
||||
os.makedirs(DPO_CHECKPOINT_DIR, exist_ok=True)
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
print("=" * 70)
|
||||
print(" DPO: PREFERENCE ALIGNMENT FOR 1B TRANSFORMER")
|
||||
print("=" * 70)
|
||||
|
||||
tokenizer = get_tokenizer()
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
model_config = ModelConfig()
|
||||
model_config.vocab_size = len(tokenizer)
|
||||
|
||||
if rank == 0:
|
||||
print(f"[Init] Loading SFT model from {SFT_CHECKPOINT}")
|
||||
|
||||
# Policy model (trainable)
|
||||
policy = Transformer(model_config)
|
||||
ckpt = torch.load(SFT_CHECKPOINT, map_location="cpu", weights_only=False)
|
||||
policy.load_state_dict(ckpt["model"])
|
||||
sft_step = ckpt.get("step", 0)
|
||||
if rank == 0:
|
||||
print(f"[Init] SFT model loaded (step {sft_step})")
|
||||
|
||||
# Reference model (frozen copy)
|
||||
ref_model = Transformer(model_config)
|
||||
ref_model.load_state_dict(ckpt["model"])
|
||||
del ckpt
|
||||
|
||||
policy = policy.to(device)
|
||||
ref_model = ref_model.to(device).bfloat16()
|
||||
ref_model.eval()
|
||||
for p in ref_model.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
policy = DDP(policy, device_ids=[local_rank])
|
||||
|
||||
if rank == 0:
|
||||
n = sum(p.numel() for p in policy.parameters())
|
||||
print(f"[Init] Params: {n:,} | GPUs: {world_size}x H100")
|
||||
print(f"[Init] Beta: {BETA} | LR: {LEARNING_RATE}")
|
||||
|
||||
# Dataset
|
||||
dataset = DPODataset(
|
||||
tokenizer=tokenizer,
|
||||
max_seq_len=MAX_SEQ_LEN,
|
||||
split="train",
|
||||
cache_dir=DATA_CACHE,
|
||||
)
|
||||
|
||||
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=BATCH_SIZE_PER_GPU,
|
||||
sampler=sampler,
|
||||
num_workers=4,
|
||||
pin_memory=True,
|
||||
collate_fn=lambda b: dpo_collate_fn(b, pad_id=tokenizer.pad_token_id),
|
||||
)
|
||||
|
||||
steps_per_epoch = len(dataloader) // GRADIENT_ACCUMULATION
|
||||
total_steps = steps_per_epoch * NUM_EPOCHS
|
||||
|
||||
if rank == 0:
|
||||
eff_batch = BATCH_SIZE_PER_GPU * world_size * GRADIENT_ACCUMULATION
|
||||
print(f"[Init] Dataset: {len(dataset):,} preference pairs")
|
||||
print(f"[Init] Effective batch: {eff_batch} | Steps/epoch: {steps_per_epoch}")
|
||||
print(f"[Init] Total steps: {total_steps}")
|
||||
print("-" * 70)
|
||||
|
||||
decay_params = [p for n, p in policy.named_parameters() if p.dim() >= 2 and p.requires_grad]
|
||||
nodecay_params = [p for n, p in policy.named_parameters() if p.dim() < 2 and p.requires_grad]
|
||||
optimizer = torch.optim.AdamW([
|
||||
{"params": decay_params, "weight_decay": WEIGHT_DECAY},
|
||||
{"params": nodecay_params, "weight_decay": 0.0},
|
||||
], lr=LEARNING_RATE, betas=(0.9, 0.95), fused=True)
|
||||
|
||||
policy.train()
|
||||
global_step = 0
|
||||
running_loss = 0.0
|
||||
running_acc = 0.0
|
||||
running_margin = 0.0
|
||||
t0 = time.time()
|
||||
|
||||
log_file = open(os.path.join(LOG_DIR, "dpo_log.jsonl"), "w") if rank == 0 else None
|
||||
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
sampler.set_epoch(epoch)
|
||||
data_iter = iter(dataloader)
|
||||
|
||||
if rank == 0:
|
||||
print(f"\n[Epoch {epoch + 1}/{NUM_EPOCHS}]")
|
||||
|
||||
while True:
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
batch_loss = 0.0
|
||||
batch_acc = 0.0
|
||||
batch_margin = 0.0
|
||||
valid_micros = 0
|
||||
|
||||
for _ in range(GRADIENT_ACCUMULATION):
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
chosen_ids = batch["chosen_ids"].to(device, non_blocking=True)
|
||||
rejected_ids = batch["rejected_ids"].to(device, non_blocking=True)
|
||||
prompt_lens = batch["prompt_lens"].to(device, non_blocking=True)
|
||||
|
||||
policy_chosen_logps = get_per_token_logps(policy, chosen_ids, prompt_lens)
|
||||
policy_rejected_logps = get_per_token_logps(policy, rejected_ids, prompt_lens)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_chosen_logps = get_per_token_logps(ref_model, chosen_ids, prompt_lens)
|
||||
ref_rejected_logps = get_per_token_logps(ref_model, rejected_ids, prompt_lens)
|
||||
|
||||
loss, acc, margin = dpo_loss(
|
||||
policy_chosen_logps, policy_rejected_logps,
|
||||
ref_chosen_logps, ref_rejected_logps,
|
||||
beta=BETA,
|
||||
)
|
||||
loss = loss / GRADIENT_ACCUMULATION
|
||||
loss.backward()
|
||||
|
||||
batch_loss += loss.item()
|
||||
batch_acc += acc
|
||||
batch_margin += margin
|
||||
valid_micros += 1
|
||||
|
||||
if valid_micros == 0:
|
||||
break
|
||||
|
||||
torch.nn.utils.clip_grad_norm_(policy.parameters(), GRAD_CLIP)
|
||||
|
||||
lr = get_cosine_lr(global_step, WARMUP_STEPS, total_steps, LEARNING_RATE, MIN_LR)
|
||||
for pg in optimizer.param_groups:
|
||||
pg["lr"] = lr
|
||||
|
||||
optimizer.step()
|
||||
global_step += 1
|
||||
running_loss += batch_loss
|
||||
running_acc += batch_acc / valid_micros
|
||||
running_margin += batch_margin / valid_micros
|
||||
|
||||
if global_step % LOG_INTERVAL == 0:
|
||||
avg_loss = running_loss / LOG_INTERVAL
|
||||
avg_acc = running_acc / LOG_INTERVAL
|
||||
avg_margin = running_margin / LOG_INTERVAL
|
||||
elapsed = time.time() - t0
|
||||
pct = 100.0 * global_step / total_steps
|
||||
eta = (elapsed / max(global_step, 1)) * (total_steps - global_step)
|
||||
|
||||
if rank == 0:
|
||||
gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
|
||||
print(
|
||||
f" [Step {global_step:>5d}/{total_steps}] "
|
||||
f"loss={avg_loss:.4f} | acc={avg_acc:.1%} | "
|
||||
f"margin={avg_margin:.3f} | lr={lr:.2e} | "
|
||||
f"GPU={gpu_mem:.1f}GB | {pct:.1f}% | ETA={eta/60:.0f}m",
|
||||
flush=True,
|
||||
)
|
||||
if log_file:
|
||||
log_file.write(json.dumps({
|
||||
"step": global_step, "loss": round(avg_loss, 4),
|
||||
"accuracy": round(avg_acc, 4),
|
||||
"reward_margin": round(avg_margin, 4),
|
||||
"lr": lr, "elapsed_s": round(elapsed, 1),
|
||||
}) + "\n")
|
||||
log_file.flush()
|
||||
|
||||
running_loss = 0.0
|
||||
running_acc = 0.0
|
||||
running_margin = 0.0
|
||||
|
||||
if global_step % SAVE_INTERVAL == 0:
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
path = os.path.join(DPO_CHECKPOINT_DIR, f"dpo_step_{global_step}.pt")
|
||||
torch.save({
|
||||
"step": global_step,
|
||||
"model": policy.module.state_dict(),
|
||||
"config": model_config.__dict__,
|
||||
"vocab_size": model_config.vocab_size,
|
||||
}, path)
|
||||
print(f" >> Checkpoint: {path}", flush=True)
|
||||
dist.barrier()
|
||||
|
||||
# Final save
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
final_path = os.path.join(DPO_CHECKPOINT_DIR, "dpo_final.pt")
|
||||
torch.save({
|
||||
"step": global_step,
|
||||
"model": policy.module.state_dict(),
|
||||
"config": model_config.__dict__,
|
||||
"vocab_size": model_config.vocab_size,
|
||||
}, final_path)
|
||||
total_time = time.time() - t0
|
||||
print("=" * 70)
|
||||
print(f" DPO COMPLETE")
|
||||
print(f" Steps: {global_step:,} | Epochs: {NUM_EPOCHS}")
|
||||
print(f" Time: {total_time/60:.1f} minutes")
|
||||
print(f" Final model: {final_path}")
|
||||
print("=" * 70)
|
||||
if log_file:
|
||||
log_file.close()
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
272
training_code/train_sft.py
Normal file
272
training_code/train_sft.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
SFT (Supervised Fine-Tuning) script for the 1B Transformer.
|
||||
|
||||
Takes the pretrained base model and fine-tunes it on instruction-response
|
||||
conversations from UltraChat 200K.
|
||||
|
||||
Launch: torchrun --nproc_per_node=8 train_sft.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import time
|
||||
import json
|
||||
import datetime
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer
|
||||
from model.sft_data import SFTDataset, sft_collate_fn
|
||||
|
||||
|
||||
# === Config ===
|
||||
BASE_CHECKPOINT = "/jfs/deepak-kumar/checkpoints/step_19000.pt"
|
||||
SFT_CHECKPOINT_DIR = "/jfs/deepak-kumar/checkpoints_sft"
|
||||
LOG_DIR = "/home/jovyan/training/logs"
|
||||
DATA_CACHE = "/jfs/deepak-kumar/data"
|
||||
|
||||
NUM_EPOCHS = 2
|
||||
BATCH_SIZE_PER_GPU = 4
|
||||
GRADIENT_ACCUMULATION = 4 # effective batch = 4 * 8 * 4 = 128
|
||||
MAX_SEQ_LEN = 2048
|
||||
LEARNING_RATE = 2e-5 # much lower than pretraining — we're fine-tuning
|
||||
MIN_LR = 2e-6
|
||||
WARMUP_STEPS = 200
|
||||
WEIGHT_DECAY = 0.01
|
||||
GRAD_CLIP = 1.0
|
||||
LOG_INTERVAL = 10
|
||||
SAVE_INTERVAL = 500
|
||||
|
||||
|
||||
def get_cosine_lr(step, warmup_steps, total_steps, max_lr, min_lr):
|
||||
if step < warmup_steps:
|
||||
return max_lr * step / max(warmup_steps, 1)
|
||||
progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
|
||||
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
|
||||
|
||||
|
||||
def main():
|
||||
dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
if rank == 0:
|
||||
os.makedirs(SFT_CHECKPOINT_DIR, exist_ok=True)
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
print("=" * 70)
|
||||
print(" SFT: INSTRUCTION FINE-TUNING 1B TRANSFORMER")
|
||||
print("=" * 70)
|
||||
|
||||
# Tokenizer
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
# Load base model
|
||||
model_config = ModelConfig()
|
||||
torch.manual_seed(42)
|
||||
model = Transformer(model_config)
|
||||
|
||||
if rank == 0:
|
||||
print(f"[Init] Loading base model from {BASE_CHECKPOINT}")
|
||||
ckpt = torch.load(BASE_CHECKPOINT, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(ckpt["model"])
|
||||
base_step = ckpt.get("step", 0)
|
||||
base_loss = ckpt.get("loss", "?")
|
||||
if rank == 0:
|
||||
print(f"[Init] Base model: step={base_step}, pretrain_loss={base_loss}")
|
||||
del ckpt
|
||||
|
||||
# Add chat tokens to embedding — expand vocab if needed
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
new_vocab_size = len(tokenizer)
|
||||
if new_vocab_size > model_config.vocab_size:
|
||||
if rank == 0:
|
||||
print(f"[Init] Expanding vocab: {model_config.vocab_size} -> {new_vocab_size}")
|
||||
|
||||
old_emb_weight = model.tok_embeddings.weight.data
|
||||
model.tok_embeddings = torch.nn.Embedding(new_vocab_size, model_config.hidden_dim)
|
||||
model.tok_embeddings.weight.data[:model_config.vocab_size] = old_emb_weight
|
||||
# Init new token embeddings as mean of existing (better than random)
|
||||
mean_emb = old_emb_weight.mean(dim=0)
|
||||
for i in range(model_config.vocab_size, new_vocab_size):
|
||||
model.tok_embeddings.weight.data[i] = mean_emb
|
||||
|
||||
old_output_weight = model.output.weight.data
|
||||
model.output = torch.nn.Linear(model_config.hidden_dim, new_vocab_size, bias=False)
|
||||
model.output.weight.data[:model_config.vocab_size] = old_output_weight
|
||||
|
||||
model.config.vocab_size = new_vocab_size
|
||||
|
||||
model = model.to(device)
|
||||
model = DDP(model, device_ids=[local_rank])
|
||||
|
||||
if rank == 0:
|
||||
n = sum(p.numel() for p in model.parameters())
|
||||
print(f"[Init] Params: {n:,} | GPUs: {world_size}x H100")
|
||||
|
||||
# Dataset (only load on each process)
|
||||
dataset = SFTDataset(
|
||||
tokenizer=tokenizer,
|
||||
max_seq_len=MAX_SEQ_LEN,
|
||||
split="train_sft",
|
||||
cache_dir=DATA_CACHE,
|
||||
)
|
||||
|
||||
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=BATCH_SIZE_PER_GPU,
|
||||
sampler=sampler,
|
||||
num_workers=4,
|
||||
pin_memory=True,
|
||||
collate_fn=lambda b: sft_collate_fn(b, pad_id=tokenizer.pad_token_id),
|
||||
)
|
||||
|
||||
steps_per_epoch = len(dataloader) // GRADIENT_ACCUMULATION
|
||||
total_steps = steps_per_epoch * NUM_EPOCHS
|
||||
|
||||
if rank == 0:
|
||||
eff_batch = BATCH_SIZE_PER_GPU * world_size * GRADIENT_ACCUMULATION
|
||||
print(f"[Init] Dataset: {len(dataset):,} examples")
|
||||
print(f"[Init] Effective batch: {eff_batch} | Steps/epoch: {steps_per_epoch}")
|
||||
print(f"[Init] Total steps: {total_steps} | Epochs: {NUM_EPOCHS}")
|
||||
print(f"[Init] LR: {LEARNING_RATE} → {MIN_LR} (cosine)")
|
||||
print("-" * 70)
|
||||
|
||||
# Optimizer — lower LR for fine-tuning
|
||||
decay_params = [p for n, p in model.named_parameters() if p.dim() >= 2 and p.requires_grad]
|
||||
nodecay_params = [p for n, p in model.named_parameters() if p.dim() < 2 and p.requires_grad]
|
||||
optimizer = torch.optim.AdamW([
|
||||
{"params": decay_params, "weight_decay": WEIGHT_DECAY},
|
||||
{"params": nodecay_params, "weight_decay": 0.0},
|
||||
], lr=LEARNING_RATE, betas=(0.9, 0.95), fused=True)
|
||||
|
||||
# Training
|
||||
model.train()
|
||||
global_step = 0
|
||||
running_loss = 0.0
|
||||
t0 = time.time()
|
||||
step_t0 = time.time()
|
||||
|
||||
log_file = open(os.path.join(LOG_DIR, "sft_log.jsonl"), "w") if rank == 0 else None
|
||||
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
sampler.set_epoch(epoch)
|
||||
data_iter = iter(dataloader)
|
||||
micro_step = 0
|
||||
|
||||
if rank == 0:
|
||||
print(f"\n[Epoch {epoch + 1}/{NUM_EPOCHS}]")
|
||||
|
||||
while True:
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
batch_loss = 0.0
|
||||
|
||||
for _ in range(GRADIENT_ACCUMULATION):
|
||||
try:
|
||||
input_ids, labels = next(data_iter)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
input_ids = input_ids.to(device, non_blocking=True)
|
||||
labels = labels.to(device, non_blocking=True)
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
_, loss = model(input_ids, labels)
|
||||
loss = loss / GRADIENT_ACCUMULATION
|
||||
|
||||
loss.backward()
|
||||
batch_loss += loss.item()
|
||||
micro_step += 1
|
||||
|
||||
if batch_loss == 0:
|
||||
break
|
||||
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
|
||||
|
||||
lr = get_cosine_lr(global_step, WARMUP_STEPS, total_steps, LEARNING_RATE, MIN_LR)
|
||||
for pg in optimizer.param_groups:
|
||||
pg["lr"] = lr
|
||||
|
||||
optimizer.step()
|
||||
global_step += 1
|
||||
running_loss += batch_loss
|
||||
|
||||
if global_step % LOG_INTERVAL == 0:
|
||||
dt = time.time() - step_t0
|
||||
avg = running_loss / LOG_INTERVAL
|
||||
elapsed = time.time() - t0
|
||||
pct = 100.0 * global_step / total_steps
|
||||
|
||||
if rank == 0:
|
||||
gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
|
||||
eta = (elapsed / max(global_step, 1)) * (total_steps - global_step)
|
||||
print(
|
||||
f" [Step {global_step:>5d}/{total_steps}] "
|
||||
f"loss={avg:.4f} | lr={lr:.2e} | "
|
||||
f"GPU={gpu_mem:.1f}GB | {pct:.1f}% | ETA={eta/60:.0f}m",
|
||||
flush=True,
|
||||
)
|
||||
if log_file:
|
||||
log_file.write(json.dumps({
|
||||
"step": global_step, "epoch": epoch + 1,
|
||||
"loss": round(avg, 4), "lr": lr,
|
||||
"elapsed_s": round(elapsed, 1),
|
||||
}) + "\n")
|
||||
log_file.flush()
|
||||
|
||||
running_loss = 0.0
|
||||
step_t0 = time.time()
|
||||
|
||||
if global_step % SAVE_INTERVAL == 0:
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
path = os.path.join(SFT_CHECKPOINT_DIR, f"sft_step_{global_step}.pt")
|
||||
torch.save({
|
||||
"step": global_step,
|
||||
"model": model.module.state_dict(),
|
||||
"config": model_config.__dict__,
|
||||
"vocab_size": new_vocab_size,
|
||||
}, path)
|
||||
print(f" >> Checkpoint: {path}", flush=True)
|
||||
dist.barrier()
|
||||
|
||||
# Final save
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
final_path = os.path.join(SFT_CHECKPOINT_DIR, "sft_final.pt")
|
||||
torch.save({
|
||||
"step": global_step,
|
||||
"model": model.module.state_dict(),
|
||||
"config": model_config.__dict__,
|
||||
"vocab_size": new_vocab_size,
|
||||
}, final_path)
|
||||
total_time = time.time() - t0
|
||||
print("=" * 70)
|
||||
print(f" SFT COMPLETE")
|
||||
print(f" Steps: {global_step:,} | Epochs: {NUM_EPOCHS}")
|
||||
print(f" Time: {total_time/60:.1f} minutes")
|
||||
print(f" Final model: {final_path}")
|
||||
print("=" * 70)
|
||||
if log_file:
|
||||
log_file.close()
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user