初始化项目,由ModelHub XC社区提供模型
Model: divakar-yadav/transformer-1b-chat Source: Original Platform
This commit is contained in:
36
.gitattributes
vendored
Normal file
36
.gitattributes
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
*.7z filter=lfs diff=lfs merge=lfs -text
|
||||
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||
*.model filter=lfs diff=lfs merge=lfs -text
|
||||
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||
training_report.pdf filter=lfs diff=lfs merge=lfs -text
|
||||
147
README.md
Normal file
147
README.md
Normal file
@@ -0,0 +1,147 @@
|
||||
---
|
||||
language:
|
||||
- en
|
||||
license: apache-2.0
|
||||
tags:
|
||||
- llama
|
||||
- causal-lm
|
||||
- from-scratch
|
||||
- dpo
|
||||
- chat
|
||||
- text-generation
|
||||
library_name: transformers
|
||||
pipeline_tag: text-generation
|
||||
model-index:
|
||||
- name: Transformer-1B-Chat
|
||||
results: []
|
||||
---
|
||||
|
||||
# Transformer-1B-Chat
|
||||
|
||||
A **1.1 billion parameter** decoder-only language model trained **entirely from scratch** -- pretraining, supervised fine-tuning, and preference alignment -- on 8x NVIDIA H100 GPUs.
|
||||
|
||||
## Model Details
|
||||
|
||||
| Property | Value |
|
||||
|---|---|
|
||||
| Parameters | 1,105,827,840 (1.1B) |
|
||||
| Architecture | LLaMA-style Decoder-only Transformer |
|
||||
| Hidden Size | 2048 |
|
||||
| Intermediate Size | 5504 (SwiGLU) |
|
||||
| Layers | 22 |
|
||||
| Attention Heads | 32 (Grouped Query Attention) |
|
||||
| KV Heads | 8 |
|
||||
| Head Dim | 64 |
|
||||
| Max Sequence Length | 2048 |
|
||||
| Vocab Size | 32,003 |
|
||||
| Precision | BFloat16 |
|
||||
|
||||
### Architecture Highlights
|
||||
|
||||
- **RoPE** (Rotary Position Embeddings) with theta=10,000
|
||||
- **Grouped Query Attention** (GQA) -- 4:1 query-to-KV head ratio for efficient inference
|
||||
- **SwiGLU** Feed-Forward Network
|
||||
- **RMSNorm** in a pre-norm configuration
|
||||
- **Flash Attention 2** via PyTorch SDPA
|
||||
|
||||
## Training Pipeline
|
||||
|
||||
This model was built through a complete 3-stage training pipeline:
|
||||
|
||||
### Stage 1: Pretraining
|
||||
|
||||
| Detail | Value |
|
||||
|---|---|
|
||||
| Dataset | HuggingFaceFW/fineweb-edu (sample-10BT) |
|
||||
| Tokens Trained | ~20B tokens |
|
||||
| Steps | 19,070 |
|
||||
| Duration | ~12.3 hours |
|
||||
| Optimizer | AdamW (lr=3e-4, betas=0.9/0.95, wd=0.1) |
|
||||
| Schedule | WSD (Warmup-Stable-Decay), warmup=1000 steps |
|
||||
| Batch Size | 512 sequences (8 GPUs x 8 micro x 8 grad accum) |
|
||||
| Final Loss | 2.43 |
|
||||
| Throughput | ~338K tokens/sec |
|
||||
|
||||
### Stage 2: Supervised Fine-Tuning (SFT)
|
||||
|
||||
| Detail | Value |
|
||||
|---|---|
|
||||
| Dataset | HuggingFaceH4/ultrachat_200k (207,865 conversations) |
|
||||
| Steps | 3,240 (2 epochs) |
|
||||
| Duration | ~52 minutes |
|
||||
| Optimizer | AdamW (lr=2e-5, cosine decay) |
|
||||
| Batch Size | 256 sequences |
|
||||
| Final Loss | 1.20 |
|
||||
|
||||
### Stage 3: Direct Preference Optimization (DPO)
|
||||
|
||||
| Detail | Value |
|
||||
|---|---|
|
||||
| Dataset | argilla/ultrafeedback-binarized-preferences-cleaned (60,917 pairs) |
|
||||
| Steps | 952 (1 epoch) |
|
||||
| Duration | ~14 minutes |
|
||||
| Optimizer | AdamW (lr=5e-7, cosine decay) |
|
||||
| Beta | 0.1 |
|
||||
| Batch Size | 64 pairs |
|
||||
| Final Loss | 0.49 |
|
||||
| Final Accuracy | 72.5% (chosen preferred over rejected) |
|
||||
| Final Reward Margin | 0.84 |
|
||||
|
||||
### Hardware
|
||||
|
||||
- **8x NVIDIA H100 80GB HBM3**
|
||||
- **Distributed Strategy**: PyTorch DDP (DistributedDataParallel)
|
||||
- **Communication**: NCCL
|
||||
- **Mixed Precision**: BF16 autocast
|
||||
- **Total Training Time**: ~13.5 hours (all 3 stages)
|
||||
|
||||
## Chat Template
|
||||
|
||||
The model uses a simple chat template with special tokens:
|
||||
|
||||
```
|
||||
<|user|>
|
||||
Your message here
|
||||
<|end|>
|
||||
<|assistant|>
|
||||
Model response here
|
||||
<|end|>
|
||||
```
|
||||
|
||||
### Special Tokens
|
||||
|
||||
| Token | ID | Purpose |
|
||||
|---|---|---|
|
||||
| `<|user|>` | 32000 | Start of user turn |
|
||||
| `<|assistant|>` | 32001 | Start of assistant turn |
|
||||
| `<|end|>` | 32002 | End of turn |
|
||||
|
||||
## Limitations
|
||||
|
||||
- **1.1B parameters** -- smaller models have inherent limitations in reasoning depth and factual accuracy
|
||||
- Trained on English data only
|
||||
- May generate plausible-sounding but incorrect information
|
||||
- The DPO alignment is single-epoch; additional iterations could improve quality
|
||||
- Not safety-tuned beyond what the UltraFeedback dataset provides
|
||||
|
||||
## Training Code
|
||||
|
||||
The full training code is open-sourced alongside this model.
|
||||
|
||||
```
|
||||
model/
|
||||
config.py # Model and training hyperparameters
|
||||
transformer.py # Full transformer implementation from scratch
|
||||
data.py # Pretraining data pipeline (FineWeb-Edu)
|
||||
sft_data.py # SFT data pipeline (UltraChat)
|
||||
dpo_data.py # DPO data pipeline (UltraFeedback)
|
||||
train.py # Pretraining script (DDP, 8-GPU)
|
||||
train_sft.py # SFT script
|
||||
train_dpo.py # DPO script
|
||||
chat.py # Interactive chat interface
|
||||
export_to_hf.py # Export to HuggingFace format
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
Apache 2.0
|
||||
24
config.json
Normal file
24
config.json
Normal file
@@ -0,0 +1,24 @@
|
||||
{
|
||||
"architectures": [
|
||||
"LlamaForCausalLM"
|
||||
],
|
||||
"model_type": "llama",
|
||||
"vocab_size": 32003,
|
||||
"hidden_size": 2048,
|
||||
"intermediate_size": 5504,
|
||||
"num_hidden_layers": 22,
|
||||
"num_attention_heads": 32,
|
||||
"num_key_value_heads": 8,
|
||||
"max_position_embeddings": 2048,
|
||||
"rope_theta": 10000.0,
|
||||
"rms_norm_eps": 1e-05,
|
||||
"hidden_act": "silu",
|
||||
"initializer_range": 0.02,
|
||||
"tie_word_embeddings": false,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.40.0",
|
||||
"use_cache": true,
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"pad_token_id": 2
|
||||
}
|
||||
12
generation_config.json
Normal file
12
generation_config.json
Normal file
@@ -0,0 +1,12 @@
|
||||
{
|
||||
"bos_token_id": 1,
|
||||
"eos_token_id": 2,
|
||||
"pad_token_id": 2,
|
||||
"do_sample": true,
|
||||
"temperature": 0.7,
|
||||
"top_k": 50,
|
||||
"top_p": 0.9,
|
||||
"repetition_penalty": 1.15,
|
||||
"max_new_tokens": 512,
|
||||
"transformers_version": "4.40.0"
|
||||
}
|
||||
3
model.safetensors
Normal file
3
model.safetensors
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b3019bbfbef68dd70bf94f15f603994fd134cdbcb79d61b1cbbfbbfaaf88083d
|
||||
size 2211678744
|
||||
11
special_tokens_map.json
Normal file
11
special_tokens_map.json
Normal file
@@ -0,0 +1,11 @@
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "</s>",
|
||||
"unk_token": "<unk>",
|
||||
"additional_special_tokens": [
|
||||
"<|user|>",
|
||||
"<|assistant|>",
|
||||
"<|end|>"
|
||||
]
|
||||
}
|
||||
268080
tokenizer.json
Normal file
268080
tokenizer.json
Normal file
File diff suppressed because it is too large
Load Diff
44
tokenizer_config.json
Normal file
44
tokenizer_config.json
Normal file
@@ -0,0 +1,44 @@
|
||||
{
|
||||
"add_prefix_space": null,
|
||||
"backend": "tokenizers",
|
||||
"bos_token": "<s>",
|
||||
"clean_up_tokenization_spaces": false,
|
||||
"eos_token": "</s>",
|
||||
"extra_special_tokens": [],
|
||||
"is_local": false,
|
||||
"legacy": false,
|
||||
"model_max_length": 2048,
|
||||
"pad_token": "</s>",
|
||||
"sp_model_kwargs": {},
|
||||
"spaces_between_special_tokens": false,
|
||||
"tokenizer_class": "LlamaTokenizerFast",
|
||||
"unk_token": "<unk>",
|
||||
"use_default_system_prompt": false,
|
||||
"added_tokens_decoder": {
|
||||
"32000": {
|
||||
"content": "<|user|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32001": {
|
||||
"content": "<|assistant|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
},
|
||||
"32002": {
|
||||
"content": "<|end|>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
}
|
||||
},
|
||||
"chat_template": "{% for message in messages %}{% if message['role'] == 'user' %}<|user|>\n{{ message['content'] }}\n<|end|>\n{% elif message['role'] == 'assistant' %}<|assistant|>\n{{ message['content'] }}\n<|end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
|
||||
}
|
||||
318
training_code/chat.py
Normal file
318
training_code/chat.py
Normal file
@@ -0,0 +1,318 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Interactive chat with the 1B Transformer.
|
||||
Runs in an infinite conversation loop from the terminal.
|
||||
|
||||
Usage:
|
||||
python chat.py # auto-find latest checkpoint
|
||||
python chat.py /jfs/deepak-kumar/checkpoints/step_19000.pt # specific checkpoint
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import glob
|
||||
import time
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import readline # enables arrow keys and history in input()
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer
|
||||
|
||||
|
||||
def find_latest_checkpoint():
|
||||
"""Look for DPO > SFT > pretrained checkpoint."""
|
||||
dpo_dir = "/jfs/deepak-kumar/checkpoints_dpo"
|
||||
sft_dir = "/jfs/deepak-kumar/checkpoints_sft"
|
||||
pt_dir = "/jfs/deepak-kumar/checkpoints"
|
||||
|
||||
# Prefer DPO final
|
||||
dpo_final = os.path.join(dpo_dir, "dpo_final.pt")
|
||||
if os.path.exists(dpo_final):
|
||||
return dpo_final, True
|
||||
|
||||
dpo_files = glob.glob(os.path.join(dpo_dir, "dpo_step_*.pt"))
|
||||
if dpo_files:
|
||||
return max(dpo_files, key=lambda f: int(f.split("dpo_step_")[1].split(".")[0])), True
|
||||
|
||||
# Then SFT
|
||||
sft_final = os.path.join(sft_dir, "sft_final.pt")
|
||||
if os.path.exists(sft_final):
|
||||
return sft_final, True
|
||||
|
||||
sft_files = glob.glob(os.path.join(sft_dir, "sft_step_*.pt"))
|
||||
if sft_files:
|
||||
return max(sft_files, key=lambda f: int(f.split("sft_step_")[1].split(".")[0])), True
|
||||
|
||||
# Fall back to pretrained
|
||||
pt_files = glob.glob(os.path.join(pt_dir, "step_*.pt"))
|
||||
if pt_files:
|
||||
return max(pt_files, key=lambda f: int(os.path.basename(f).split("_")[1].split(".")[0])), False
|
||||
|
||||
return None, False
|
||||
|
||||
|
||||
def load_model(checkpoint_path, tokenizer, device="cuda:0"):
|
||||
config = ModelConfig()
|
||||
model = Transformer(config)
|
||||
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
||||
|
||||
# Handle expanded vocab from SFT
|
||||
saved_vocab = ckpt.get("vocab_size", config.vocab_size)
|
||||
if saved_vocab > config.vocab_size:
|
||||
config.vocab_size = saved_vocab
|
||||
model = Transformer(config)
|
||||
|
||||
model.load_state_dict(ckpt["model"])
|
||||
model = model.to(device).bfloat16().eval()
|
||||
step = ckpt.get("step", "?")
|
||||
loss = ckpt.get("loss", "?")
|
||||
del ckpt
|
||||
torch.cuda.empty_cache()
|
||||
return model, config, step, loss
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_stream(model, tokenizer, prompt, max_new_tokens=512,
|
||||
temperature=0.8, top_k=50, top_p=0.9,
|
||||
repetition_penalty=1.15, device="cuda:0",
|
||||
stop_token_ids=None):
|
||||
"""Generate tokens one at a time, yielding each for streaming output."""
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
||||
generated_ids = []
|
||||
prev_decoded_len = 0
|
||||
|
||||
if stop_token_ids is None:
|
||||
stop_token_ids = set()
|
||||
else:
|
||||
stop_token_ids = set(stop_token_ids)
|
||||
stop_token_ids.add(tokenizer.eos_token_id)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
if input_ids.shape[1] >= model.config.max_seq_len:
|
||||
break
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
logits, _ = model(input_ids)
|
||||
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if repetition_penalty != 1.0 and generated_ids:
|
||||
prev_tokens = torch.tensor(generated_ids, device=device).unique()
|
||||
for token_id in prev_tokens:
|
||||
if logits[0, token_id] > 0:
|
||||
logits[0, token_id] /= repetition_penalty
|
||||
else:
|
||||
logits[0, token_id] *= repetition_penalty
|
||||
|
||||
logits = logits / temperature
|
||||
|
||||
if top_k > 0:
|
||||
topk_vals, _ = torch.topk(logits, top_k)
|
||||
logits[logits < topk_vals[:, -1:]] = float("-inf")
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
mask = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p
|
||||
sorted_logits[mask] = float("-inf")
|
||||
logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
|
||||
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
token_id = next_token.item()
|
||||
|
||||
# Stop on any stop token (EOS, <|end|>, <|user|>)
|
||||
if token_id in stop_token_ids:
|
||||
break
|
||||
|
||||
generated_ids.append(token_id)
|
||||
input_ids = torch.cat([input_ids, next_token], dim=1)
|
||||
|
||||
full_decoded = tokenizer.decode(generated_ids, skip_special_tokens=True)
|
||||
new_text = full_decoded[prev_decoded_len:]
|
||||
prev_decoded_len = len(full_decoded)
|
||||
yield new_text
|
||||
|
||||
return
|
||||
|
||||
|
||||
def print_banner(step, loss, device):
|
||||
print("\033[1;36m") # cyan bold
|
||||
print("=" * 60)
|
||||
print(" 1B TRANSFORMER — Interactive Chat")
|
||||
print("=" * 60)
|
||||
print(f"\033[0m Checkpoint : step {step}")
|
||||
print(f" Loss : {loss}")
|
||||
print(f" Device : {device}")
|
||||
print(f" Parameters : 1.106B")
|
||||
print()
|
||||
print(" \033[90mCommands:\033[0m")
|
||||
print(" \033[33m/quit\033[0m — exit")
|
||||
print(" \033[33m/clear\033[0m — clear conversation context")
|
||||
print(" \033[33m/temp N\033[0m — set temperature (default 0.8)")
|
||||
print(" \033[33m/tokens N\033[0m — set max tokens (default 512)")
|
||||
print(" \033[33m/topp N\033[0m — set top-p (default 0.9)")
|
||||
print(" \033[33m/topk N\033[0m — set top-k (default 50)")
|
||||
print(" \033[33m/rep N\033[0m — set repetition penalty (default 1.15)")
|
||||
print()
|
||||
print("\033[90m" + "─" * 60 + "\033[0m")
|
||||
|
||||
|
||||
def main():
|
||||
device = "cuda:0"
|
||||
|
||||
is_sft = False
|
||||
if len(sys.argv) > 1:
|
||||
checkpoint = sys.argv[1]
|
||||
is_sft = "sft" in checkpoint.lower()
|
||||
else:
|
||||
result = find_latest_checkpoint()
|
||||
if result[0] is None:
|
||||
print("No checkpoint found!")
|
||||
sys.exit(1)
|
||||
checkpoint, is_sft = result
|
||||
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
# Add chat tokens for SFT models
|
||||
if is_sft:
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
print(f"\n Loading model from {checkpoint}...")
|
||||
print(f" Mode: {'SFT (chat)' if is_sft else 'Base (completion)'}")
|
||||
model, config, step, loss = load_model(checkpoint, tokenizer, device)
|
||||
print(f" Model loaded!\n")
|
||||
|
||||
print_banner(step, loss, device)
|
||||
if is_sft:
|
||||
print(" \033[1;32mSFT mode: The model will respond as a chat assistant.\033[0m\n")
|
||||
|
||||
# Settings
|
||||
temperature = 0.7 if is_sft else 0.8
|
||||
max_tokens = 512
|
||||
top_p = 0.9
|
||||
top_k = 50
|
||||
rep_penalty = 1.15
|
||||
context = ""
|
||||
|
||||
# Chat template tokens for SFT
|
||||
USER_START = "<|user|>\n"
|
||||
ASST_START = "<|assistant|>\n"
|
||||
TURN_END = "\n<|end|>\n"
|
||||
|
||||
# Build stop token IDs for generation
|
||||
sft_stop_ids = []
|
||||
if is_sft:
|
||||
vocab = tokenizer.get_vocab()
|
||||
for tok_str in ["<|end|>", "<|user|>"]:
|
||||
if tok_str in vocab:
|
||||
sft_stop_ids.append(vocab[tok_str])
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("\n\033[1;32mYou:\033[0m ").strip()
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
print("\n\nGoodbye!")
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
# Handle commands
|
||||
if user_input.startswith("/"):
|
||||
cmd = user_input.lower().split()
|
||||
if cmd[0] == "/quit":
|
||||
print("Goodbye!")
|
||||
break
|
||||
elif cmd[0] == "/clear":
|
||||
context = ""
|
||||
print("\033[90m [Context cleared]\033[0m")
|
||||
continue
|
||||
elif cmd[0] == "/temp" and len(cmd) > 1:
|
||||
temperature = float(cmd[1])
|
||||
print(f"\033[90m [Temperature set to {temperature}]\033[0m")
|
||||
continue
|
||||
elif cmd[0] == "/tokens" and len(cmd) > 1:
|
||||
max_tokens = int(cmd[1])
|
||||
print(f"\033[90m [Max tokens set to {max_tokens}]\033[0m")
|
||||
continue
|
||||
elif cmd[0] == "/topp" and len(cmd) > 1:
|
||||
top_p = float(cmd[1])
|
||||
print(f"\033[90m [Top-p set to {top_p}]\033[0m")
|
||||
continue
|
||||
elif cmd[0] == "/topk" and len(cmd) > 1:
|
||||
top_k = int(cmd[1])
|
||||
print(f"\033[90m [Top-k set to {top_k}]\033[0m")
|
||||
continue
|
||||
elif cmd[0] == "/rep" and len(cmd) > 1:
|
||||
rep_penalty = float(cmd[1])
|
||||
print(f"\033[90m [Repetition penalty set to {rep_penalty}]\033[0m")
|
||||
continue
|
||||
else:
|
||||
print("\033[90m Unknown command. Try /quit, /clear, /temp, /tokens, /topp, /topk, /rep\033[0m")
|
||||
continue
|
||||
|
||||
# Build prompt
|
||||
if is_sft:
|
||||
prompt = context + USER_START + user_input + TURN_END + ASST_START
|
||||
else:
|
||||
if context:
|
||||
prompt = context + "\n" + user_input
|
||||
else:
|
||||
prompt = user_input
|
||||
|
||||
# Trim context if too long
|
||||
while len(tokenizer.encode(prompt)) > config.max_seq_len - max_tokens:
|
||||
if is_sft:
|
||||
parts = context.split(TURN_END)
|
||||
if len(parts) <= 2:
|
||||
break
|
||||
context = TURN_END.join(parts[2:])
|
||||
prompt = context + USER_START + user_input + TURN_END + ASST_START
|
||||
else:
|
||||
lines = prompt.split("\n")
|
||||
if len(lines) <= 2:
|
||||
break
|
||||
prompt = "\n".join(lines[1:])
|
||||
|
||||
# Generate with streaming
|
||||
print("\033[1;34mModel:\033[0m ", end="", flush=True)
|
||||
t0 = time.time()
|
||||
full_response = ""
|
||||
token_count = 0
|
||||
|
||||
for token_text in generate_stream(
|
||||
model, tokenizer, prompt,
|
||||
max_new_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
repetition_penalty=rep_penalty,
|
||||
device=device,
|
||||
stop_token_ids=sft_stop_ids if is_sft else None,
|
||||
):
|
||||
print(token_text, end="", flush=True)
|
||||
full_response += token_text
|
||||
token_count += 1
|
||||
|
||||
elapsed = time.time() - t0
|
||||
tps = token_count / max(elapsed, 1e-9)
|
||||
print(f"\n\033[90m [{token_count} tokens, {tps:.1f} tok/s, {elapsed:.1f}s]\033[0m")
|
||||
|
||||
# Append to context for multi-turn
|
||||
if is_sft:
|
||||
context = (context + USER_START + user_input + TURN_END +
|
||||
ASST_START + full_response.strip() + TURN_END)
|
||||
else:
|
||||
context = prompt + full_response
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
167
training_code/export_to_hf.py
Normal file
167
training_code/export_to_hf.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Export the trained model to HuggingFace-compatible format.
|
||||
|
||||
Creates:
|
||||
- model.safetensors (weights)
|
||||
- config.json (architecture config)
|
||||
- generation_config.json
|
||||
- tokenizer.json, tokenizer_config.json, special_tokens_map.json
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
import torch
|
||||
from collections import OrderedDict
|
||||
from safetensors.torch import save_file
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer
|
||||
|
||||
CHECKPOINT = "/jfs/deepak-kumar/checkpoints_dpo/dpo_final.pt"
|
||||
OUTPUT_DIR = "/home/jovyan/training/hf_model"
|
||||
|
||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||
|
||||
print("=" * 60)
|
||||
print(" EXPORTING MODEL TO HUGGING FACE FORMAT")
|
||||
print("=" * 60)
|
||||
|
||||
# --- 1. Load model ---
|
||||
print("\n[1/4] Loading checkpoint...")
|
||||
tokenizer = get_tokenizer()
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
model_config = ModelConfig()
|
||||
model_config.vocab_size = len(tokenizer)
|
||||
|
||||
model = Transformer(model_config)
|
||||
ckpt = torch.load(CHECKPOINT, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(ckpt["model"])
|
||||
step = ckpt.get("step", 0)
|
||||
del ckpt
|
||||
print(f" Loaded DPO model (step {step}, vocab {model_config.vocab_size})")
|
||||
|
||||
# --- 2. Convert state dict keys to HF-style naming ---
|
||||
print("\n[2/4] Converting weights to safetensors...")
|
||||
|
||||
state_dict = model.state_dict()
|
||||
hf_state = OrderedDict()
|
||||
|
||||
KEY_MAP = {
|
||||
"tok_embeddings.weight": "model.embed_tokens.weight",
|
||||
"norm.weight": "model.norm.weight",
|
||||
"output.weight": "lm_head.weight",
|
||||
}
|
||||
|
||||
for key, tensor in state_dict.items():
|
||||
if key in KEY_MAP:
|
||||
hf_state[KEY_MAP[key]] = tensor
|
||||
continue
|
||||
|
||||
if key.startswith("layers."):
|
||||
parts = key.split(".")
|
||||
layer_idx = parts[1]
|
||||
rest = ".".join(parts[2:])
|
||||
|
||||
layer_map = {
|
||||
"attention_norm.weight": f"model.layers.{layer_idx}.input_layernorm.weight",
|
||||
"ffn_norm.weight": f"model.layers.{layer_idx}.post_attention_layernorm.weight",
|
||||
"attention.wq.weight": f"model.layers.{layer_idx}.self_attn.q_proj.weight",
|
||||
"attention.wk.weight": f"model.layers.{layer_idx}.self_attn.k_proj.weight",
|
||||
"attention.wv.weight": f"model.layers.{layer_idx}.self_attn.v_proj.weight",
|
||||
"attention.wo.weight": f"model.layers.{layer_idx}.self_attn.o_proj.weight",
|
||||
"ffn.w_gate.weight": f"model.layers.{layer_idx}.mlp.gate_proj.weight",
|
||||
"ffn.w_up.weight": f"model.layers.{layer_idx}.mlp.up_proj.weight",
|
||||
"ffn.w_down.weight": f"model.layers.{layer_idx}.mlp.down_proj.weight",
|
||||
}
|
||||
|
||||
if rest in layer_map:
|
||||
hf_state[layer_map[rest]] = tensor
|
||||
else:
|
||||
print(f" WARNING: unmapped key {key}")
|
||||
hf_state[key] = tensor
|
||||
elif key == "freqs_cis":
|
||||
continue
|
||||
else:
|
||||
print(f" WARNING: unmapped key {key}")
|
||||
hf_state[key] = tensor
|
||||
|
||||
# Convert all to bfloat16 for storage
|
||||
for k in hf_state:
|
||||
if hf_state[k].dtype == torch.float32:
|
||||
hf_state[k] = hf_state[k].to(torch.bfloat16)
|
||||
|
||||
safetensors_path = os.path.join(OUTPUT_DIR, "model.safetensors")
|
||||
save_file(hf_state, safetensors_path)
|
||||
size_gb = os.path.getsize(safetensors_path) / 1e9
|
||||
print(f" Saved {len(hf_state)} tensors -> {safetensors_path} ({size_gb:.2f} GB)")
|
||||
|
||||
# --- 3. Write config files ---
|
||||
print("\n[3/4] Writing config files...")
|
||||
|
||||
config_json = {
|
||||
"architectures": ["LlamaForCausalLM"],
|
||||
"model_type": "llama",
|
||||
"vocab_size": model_config.vocab_size,
|
||||
"hidden_size": model_config.hidden_dim,
|
||||
"intermediate_size": model_config.intermediate_dim,
|
||||
"num_hidden_layers": model_config.num_layers,
|
||||
"num_attention_heads": model_config.num_attention_heads,
|
||||
"num_key_value_heads": model_config.num_kv_heads,
|
||||
"max_position_embeddings": model_config.max_seq_len,
|
||||
"rope_theta": model_config.rope_theta,
|
||||
"rms_norm_eps": model_config.rms_norm_eps,
|
||||
"hidden_act": "silu",
|
||||
"initializer_range": 0.02,
|
||||
"tie_word_embeddings": False,
|
||||
"torch_dtype": "bfloat16",
|
||||
"transformers_version": "4.40.0",
|
||||
"use_cache": True,
|
||||
"bos_token_id": tokenizer.bos_token_id,
|
||||
"eos_token_id": tokenizer.eos_token_id,
|
||||
"pad_token_id": tokenizer.pad_token_id,
|
||||
}
|
||||
|
||||
with open(os.path.join(OUTPUT_DIR, "config.json"), "w") as f:
|
||||
json.dump(config_json, f, indent=2)
|
||||
print(" config.json")
|
||||
|
||||
gen_config = {
|
||||
"bos_token_id": tokenizer.bos_token_id,
|
||||
"eos_token_id": tokenizer.eos_token_id,
|
||||
"pad_token_id": tokenizer.pad_token_id,
|
||||
"do_sample": True,
|
||||
"temperature": 0.7,
|
||||
"top_k": 50,
|
||||
"top_p": 0.9,
|
||||
"repetition_penalty": 1.15,
|
||||
"max_new_tokens": 512,
|
||||
"transformers_version": "4.40.0",
|
||||
}
|
||||
|
||||
with open(os.path.join(OUTPUT_DIR, "generation_config.json"), "w") as f:
|
||||
json.dump(gen_config, f, indent=2)
|
||||
print(" generation_config.json")
|
||||
|
||||
# --- 4. Export tokenizer ---
|
||||
print("\n[4/4] Exporting tokenizer...")
|
||||
tokenizer.save_pretrained(OUTPUT_DIR)
|
||||
print(" Tokenizer files saved")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print(" EXPORT COMPLETE -> " + OUTPUT_DIR)
|
||||
print("=" * 60)
|
||||
print("\nFiles:")
|
||||
for f in sorted(os.listdir(OUTPUT_DIR)):
|
||||
size = os.path.getsize(os.path.join(OUTPUT_DIR, f))
|
||||
if size > 1e6:
|
||||
print(f" {f:40s} {size/1e6:.1f} MB")
|
||||
else:
|
||||
print(f" {f:40s} {size/1e3:.1f} KB")
|
||||
134
training_code/inference.py
Normal file
134
training_code/inference.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
Inference script for the 1B Transformer — Single GPU.
|
||||
|
||||
Usage:
|
||||
python inference.py # auto-finds latest checkpoint
|
||||
python inference.py /path/to/checkpoint.pt # specific checkpoint
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import glob
|
||||
import time
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer
|
||||
|
||||
|
||||
def find_latest_checkpoint(checkpoint_dir="/jfs/deepak-kumar/checkpoints"):
|
||||
files = glob.glob(os.path.join(checkpoint_dir, "step_*.pt"))
|
||||
if not files:
|
||||
final = os.path.join(checkpoint_dir, "final.pt")
|
||||
return final if os.path.exists(final) else None
|
||||
return max(files, key=lambda f: int(os.path.basename(f).split("_")[1].split(".")[0]))
|
||||
|
||||
|
||||
def load_model(checkpoint_path, device="cuda:0"):
|
||||
config = ModelConfig()
|
||||
model = Transformer(config)
|
||||
|
||||
print(f"Loading checkpoint: {checkpoint_path}")
|
||||
ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(ckpt["model"])
|
||||
model = model.to(device).bfloat16().eval()
|
||||
|
||||
step = ckpt.get("step", "?")
|
||||
loss = ckpt.get("loss", "?")
|
||||
print(f" Step: {step} | Loss: {loss}")
|
||||
print(f" Params: {sum(p.numel() for p in model.parameters()):,}")
|
||||
print(f" Device: {device}")
|
||||
del ckpt
|
||||
torch.cuda.empty_cache()
|
||||
return model, config
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(model, tokenizer, prompt, max_new_tokens=200,
|
||||
temperature=0.8, top_k=50, top_p=0.9, device="cuda:0"):
|
||||
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)
|
||||
t0 = time.time()
|
||||
|
||||
for i in range(max_new_tokens):
|
||||
if input_ids.shape[1] >= model.config.max_seq_len:
|
||||
break
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
logits, _ = model(input_ids)
|
||||
|
||||
logits = logits[:, -1, :] / temperature
|
||||
|
||||
if top_k > 0:
|
||||
topk_vals, _ = torch.topk(logits, top_k)
|
||||
logits[logits < topk_vals[:, -1:]] = float("-inf")
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
mask = cum_probs - F.softmax(sorted_logits, dim=-1) >= top_p
|
||||
sorted_logits[mask] = float("-inf")
|
||||
logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
|
||||
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, num_samples=1)
|
||||
|
||||
if next_token.item() == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
input_ids = torch.cat([input_ids, next_token], dim=1)
|
||||
|
||||
elapsed = time.time() - t0
|
||||
gen_tokens = input_ids.shape[1] - len(tokenizer.encode(prompt))
|
||||
tok_per_sec = gen_tokens / max(elapsed, 1e-9)
|
||||
|
||||
text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
|
||||
return text, gen_tokens, tok_per_sec
|
||||
|
||||
|
||||
def main():
|
||||
device = "cuda:0"
|
||||
if len(sys.argv) > 1:
|
||||
checkpoint = sys.argv[1]
|
||||
else:
|
||||
checkpoint = find_latest_checkpoint()
|
||||
if checkpoint is None:
|
||||
print("No checkpoint found!")
|
||||
sys.exit(1)
|
||||
|
||||
model, config = load_model(checkpoint, device)
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
prompts = [
|
||||
"The meaning of life is",
|
||||
"In machine learning, a neural network",
|
||||
"The capital of France is",
|
||||
"Once upon a time, there was a",
|
||||
"To solve a quadratic equation, you need to",
|
||||
"The theory of relativity explains that",
|
||||
"Python is a programming language that",
|
||||
"The sun rises in the east and",
|
||||
]
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
print(" INFERENCE — 1B Transformer (Single GPU)")
|
||||
print("=" * 70)
|
||||
|
||||
for prompt in prompts:
|
||||
print(f"\n{'─' * 60}")
|
||||
print(f"PROMPT: {prompt}")
|
||||
print(f"{'─' * 60}")
|
||||
text, n_tok, tps = generate(model, tokenizer, prompt,
|
||||
max_new_tokens=150, temperature=0.8,
|
||||
top_k=50, device=device)
|
||||
generated = text[len(prompt):]
|
||||
print(f"OUTPUT:{generated}")
|
||||
print(f" [{n_tok} tokens, {tps:.1f} tok/s]")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
2
training_code/model/__init__.py
Normal file
2
training_code/model/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .config import ModelConfig, TrainConfig
|
||||
from .transformer import Transformer
|
||||
78
training_code/model/config.py
Normal file
78
training_code/model/config.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Configuration for 1B parameter LLaMA-style Transformer model.
|
||||
Architecture: Decoder-only Transformer with RoPE, GQA, SwiGLU, RMSNorm.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelConfig:
|
||||
vocab_size: int = 32000
|
||||
hidden_dim: int = 2048
|
||||
intermediate_dim: int = 5504 # ~2.7x hidden for SwiGLU (adjusted for param count)
|
||||
num_layers: int = 22
|
||||
num_attention_heads: int = 32
|
||||
num_kv_heads: int = 8 # GQA: 4 query heads per KV head
|
||||
max_seq_len: int = 2048
|
||||
rope_theta: float = 10000.0
|
||||
rms_norm_eps: float = 1e-5
|
||||
dropout: float = 0.0 # No dropout (modern practice for pretraining)
|
||||
tie_word_embeddings: bool = False
|
||||
|
||||
@property
|
||||
def head_dim(self) -> int:
|
||||
return self.hidden_dim // self.num_attention_heads
|
||||
|
||||
@property
|
||||
def num_params_approx(self) -> int:
|
||||
"""Rough parameter count estimate."""
|
||||
embed = self.vocab_size * self.hidden_dim
|
||||
attn_per_layer = (
|
||||
self.hidden_dim * self.head_dim * self.num_attention_heads + # Q
|
||||
self.hidden_dim * self.head_dim * self.num_kv_heads + # K
|
||||
self.hidden_dim * self.head_dim * self.num_kv_heads + # V
|
||||
self.head_dim * self.num_attention_heads * self.hidden_dim # O
|
||||
)
|
||||
ffn_per_layer = 3 * self.hidden_dim * self.intermediate_dim # gate + up + down
|
||||
norm_per_layer = 2 * self.hidden_dim
|
||||
total = (
|
||||
embed +
|
||||
self.num_layers * (attn_per_layer + ffn_per_layer + norm_per_layer) +
|
||||
self.hidden_dim + # final norm
|
||||
(0 if self.tie_word_embeddings else self.vocab_size * self.hidden_dim)
|
||||
)
|
||||
return total
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainConfig:
|
||||
# Paths
|
||||
checkpoint_dir: str = "/jfs/deepak-kumar/checkpoints"
|
||||
data_cache_dir: str = "/jfs/deepak-kumar/data"
|
||||
log_dir: str = "/home/jovyan/training/logs"
|
||||
|
||||
# Training
|
||||
total_tokens: int = 20_000_000_000 # 20B tokens
|
||||
batch_size_per_gpu: int = 8
|
||||
gradient_accumulation_steps: int = 8 # effective batch = 8 * 8 * 8 = 512 seqs
|
||||
max_seq_len: int = 2048
|
||||
|
||||
# WSD Schedule
|
||||
learning_rate: float = 3e-4
|
||||
min_lr: float = 3e-5
|
||||
warmup_steps: int = 1000
|
||||
weight_decay: float = 0.1
|
||||
beta1: float = 0.9
|
||||
beta2: float = 0.95
|
||||
grad_clip: float = 1.0
|
||||
|
||||
# Logging
|
||||
log_interval: int = 10
|
||||
save_interval: int = 1000
|
||||
eval_interval: int = 500
|
||||
|
||||
# System
|
||||
num_workers: int = 4
|
||||
seed: int = 42
|
||||
bf16: bool = True
|
||||
79
training_code/model/data.py
Normal file
79
training_code/model/data.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""
|
||||
Data pipeline: streams and tokenizes OpenWebText for pretraining.
|
||||
Packs sequences to max_seq_len for efficiency (no padding waste).
|
||||
"""
|
||||
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.data import IterableDataset, DataLoader
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def get_tokenizer(name: str = "mistralai/Mistral-7B-v0.1"):
|
||||
"""Use Mistral's tokenizer — 32k vocab, BPE, well-trained on diverse data."""
|
||||
tok = AutoTokenizer.from_pretrained(name, use_fast=True)
|
||||
if tok.pad_token is None:
|
||||
tok.pad_token = tok.eos_token
|
||||
return tok
|
||||
|
||||
|
||||
class PackedPretrainDataset(IterableDataset):
|
||||
"""
|
||||
Streams text from HuggingFace dataset, tokenizes on the fly,
|
||||
and packs into fixed-length sequences for maximum GPU utilization.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, max_seq_len: int, split: str = "train", cache_dir: str = None, seed: int = 42):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_seq_len = max_seq_len
|
||||
self.split = split
|
||||
self.cache_dir = cache_dir
|
||||
self.seed = seed
|
||||
self.eos_id = tokenizer.eos_token_id
|
||||
|
||||
def _token_stream(self):
|
||||
ds = load_dataset(
|
||||
"HuggingFaceFW/fineweb-edu",
|
||||
name="sample-10BT",
|
||||
split=self.split,
|
||||
streaming=True,
|
||||
cache_dir=self.cache_dir,
|
||||
)
|
||||
ds = ds.shuffle(seed=self.seed, buffer_size=10_000)
|
||||
|
||||
for example in ds:
|
||||
text = example.get("text", "")
|
||||
if len(text.strip()) < 50:
|
||||
continue
|
||||
token_ids = self.tokenizer.encode(text, add_special_tokens=False)
|
||||
yield from token_ids
|
||||
yield self.eos_id
|
||||
|
||||
def __iter__(self):
|
||||
buffer = []
|
||||
for token_id in self._token_stream():
|
||||
buffer.append(token_id)
|
||||
if len(buffer) == self.max_seq_len + 1:
|
||||
input_ids = torch.tensor(buffer[:-1], dtype=torch.long)
|
||||
labels = torch.tensor(buffer[1:], dtype=torch.long)
|
||||
yield input_ids, labels
|
||||
buffer = []
|
||||
|
||||
|
||||
def create_dataloader(tokenizer, config, rank: int = 0, world_size: int = 1, seed_override: int = None):
|
||||
seed = seed_override if seed_override is not None else config.seed
|
||||
dataset = PackedPretrainDataset(
|
||||
tokenizer=tokenizer,
|
||||
max_seq_len=config.max_seq_len,
|
||||
split="train",
|
||||
cache_dir=config.data_cache_dir,
|
||||
seed=seed + rank,
|
||||
)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=config.batch_size_per_gpu,
|
||||
num_workers=config.num_workers,
|
||||
pin_memory=True,
|
||||
prefetch_factor=4,
|
||||
)
|
||||
144
training_code/model/dpo_data.py
Normal file
144
training_code/model/dpo_data.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""
|
||||
DPO data pipeline: loads UltraFeedback preference pairs.
|
||||
|
||||
Each example has a prompt + chosen response + rejected response.
|
||||
We tokenize both (prompt+chosen) and (prompt+rejected), apply the same
|
||||
chat template, and return them as pairs for DPO training.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
CHAT_TEMPLATE = {
|
||||
"user_start": "<|user|>\n",
|
||||
"assistant_start": "<|assistant|>\n",
|
||||
"turn_end": "\n<|end|>\n",
|
||||
}
|
||||
|
||||
|
||||
def format_preference_pair(prompt, chosen_msgs, rejected_msgs):
|
||||
"""Build chat-templated strings for chosen and rejected."""
|
||||
def build(messages):
|
||||
text = CHAT_TEMPLATE["user_start"] + prompt.strip() + CHAT_TEMPLATE["turn_end"]
|
||||
for msg in messages:
|
||||
role = msg.get("role", "assistant")
|
||||
content = msg.get("content", "").strip()
|
||||
if role == "assistant":
|
||||
text += CHAT_TEMPLATE["assistant_start"] + content + CHAT_TEMPLATE["turn_end"]
|
||||
elif role == "user":
|
||||
text += CHAT_TEMPLATE["user_start"] + content + CHAT_TEMPLATE["turn_end"]
|
||||
return text
|
||||
|
||||
return build(chosen_msgs), build(rejected_msgs)
|
||||
|
||||
|
||||
class DPODataset(Dataset):
|
||||
"""
|
||||
Loads UltraFeedback preference pairs and tokenizes them.
|
||||
Returns (prompt_ids, chosen_ids, rejected_ids) with proper shifting.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, max_seq_len=2048, split="train",
|
||||
cache_dir=None, max_samples=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
self.assistant_token_id = tokenizer.encode("<|assistant|>", add_special_tokens=False)[0]
|
||||
self.end_token_id = tokenizer.encode("<|end|>", add_special_tokens=False)[0]
|
||||
self.user_token_id = tokenizer.encode("<|user|>", add_special_tokens=False)[0]
|
||||
|
||||
print(f"[DPO Data] Loading UltraFeedback preferences ({split})...")
|
||||
ds = load_dataset(
|
||||
"argilla/ultrafeedback-binarized-preferences-cleaned",
|
||||
split=split,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
if max_samples:
|
||||
ds = ds.select(range(min(max_samples, len(ds))))
|
||||
print(f"[DPO Data] {len(ds)} preference pairs loaded")
|
||||
|
||||
self.examples = []
|
||||
skipped = 0
|
||||
for i, row in enumerate(ds):
|
||||
prompt = row.get("prompt", "")
|
||||
chosen = row.get("chosen", [])
|
||||
rejected = row.get("rejected", [])
|
||||
|
||||
if not prompt or not chosen or not rejected:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
chosen_text, rejected_text = format_preference_pair(prompt, chosen, rejected)
|
||||
|
||||
chosen_ids = tokenizer.encode(chosen_text, add_special_tokens=False)
|
||||
rejected_ids = tokenizer.encode(rejected_text, add_special_tokens=False)
|
||||
|
||||
# Truncate if needed
|
||||
if len(chosen_ids) > max_seq_len + 1:
|
||||
chosen_ids = chosen_ids[:max_seq_len + 1]
|
||||
if len(rejected_ids) > max_seq_len + 1:
|
||||
rejected_ids = rejected_ids[:max_seq_len + 1]
|
||||
|
||||
if len(chosen_ids) < 10 or len(rejected_ids) < 10:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Find where the prompt ends (first <|assistant|> token)
|
||||
prompt_end = 0
|
||||
for j, tid in enumerate(chosen_ids):
|
||||
if tid == self.assistant_token_id:
|
||||
prompt_end = j + 2 # skip <|assistant|> and \n
|
||||
break
|
||||
|
||||
self.examples.append({
|
||||
"chosen_ids": chosen_ids,
|
||||
"rejected_ids": rejected_ids,
|
||||
"prompt_len": prompt_end,
|
||||
})
|
||||
|
||||
if (i + 1) % 20000 == 0:
|
||||
print(f" Processed {i+1} pairs...")
|
||||
|
||||
print(f"[DPO Data] {len(self.examples)} pairs ready, {skipped} skipped")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
ex = self.examples[idx]
|
||||
return {
|
||||
"chosen_ids": torch.tensor(ex["chosen_ids"], dtype=torch.long),
|
||||
"rejected_ids": torch.tensor(ex["rejected_ids"], dtype=torch.long),
|
||||
"prompt_len": ex["prompt_len"],
|
||||
}
|
||||
|
||||
|
||||
def dpo_collate_fn(batch, pad_id=0):
|
||||
"""Pad chosen and rejected sequences separately."""
|
||||
max_chosen = max(b["chosen_ids"].size(0) for b in batch)
|
||||
max_rejected = max(b["rejected_ids"].size(0) for b in batch)
|
||||
|
||||
chosen_padded = []
|
||||
rejected_padded = []
|
||||
prompt_lens = []
|
||||
|
||||
for b in batch:
|
||||
c_pad = max_chosen - b["chosen_ids"].size(0)
|
||||
r_pad = max_rejected - b["rejected_ids"].size(0)
|
||||
chosen_padded.append(torch.cat([b["chosen_ids"], torch.full((c_pad,), pad_id, dtype=torch.long)]))
|
||||
rejected_padded.append(torch.cat([b["rejected_ids"], torch.full((r_pad,), pad_id, dtype=torch.long)]))
|
||||
prompt_lens.append(b["prompt_len"])
|
||||
|
||||
return {
|
||||
"chosen_ids": torch.stack(chosen_padded),
|
||||
"rejected_ids": torch.stack(rejected_padded),
|
||||
"prompt_lens": torch.tensor(prompt_lens, dtype=torch.long),
|
||||
}
|
||||
169
training_code/model/sft_data.py
Normal file
169
training_code/model/sft_data.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
SFT data pipeline: loads UltraChat 200K and formats into chat template.
|
||||
|
||||
Chat template:
|
||||
<|user|>
|
||||
What is gravity?
|
||||
<|end|>
|
||||
<|assistant|>
|
||||
Gravity is a fundamental force...
|
||||
<|end|>
|
||||
|
||||
Labels are shifted left by 1 (standard causal LM), with user turns masked.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
CHAT_TEMPLATE = {
|
||||
"user_start": "<|user|>\n",
|
||||
"assistant_start": "<|assistant|>\n",
|
||||
"turn_end": "\n<|end|>\n",
|
||||
}
|
||||
|
||||
|
||||
def format_conversation(messages):
|
||||
"""Convert a list of {role, content} messages into our chat template string."""
|
||||
text = ""
|
||||
for msg in messages:
|
||||
role = msg["role"]
|
||||
content = msg["content"].strip()
|
||||
if role == "user":
|
||||
text += CHAT_TEMPLATE["user_start"] + content + CHAT_TEMPLATE["turn_end"]
|
||||
elif role == "assistant":
|
||||
text += CHAT_TEMPLATE["assistant_start"] + content + CHAT_TEMPLATE["turn_end"]
|
||||
return text
|
||||
|
||||
|
||||
class SFTDataset(Dataset):
|
||||
"""
|
||||
Loads UltraChat 200K conversations, tokenizes them, builds shifted labels
|
||||
with user turns masked so the model only learns to generate assistant responses.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer, max_seq_len=2048, split="train_sft", cache_dir=None, max_samples=None):
|
||||
self.tokenizer = tokenizer
|
||||
self.max_seq_len = max_seq_len
|
||||
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
self.assistant_token_id = tokenizer.encode("<|assistant|>", add_special_tokens=False)[0]
|
||||
self.end_token_id = tokenizer.encode("<|end|>", add_special_tokens=False)[0]
|
||||
self.user_token_id = tokenizer.encode("<|user|>", add_special_tokens=False)[0]
|
||||
|
||||
print(f"[SFT Data] Loading UltraChat 200K ({split})...")
|
||||
ds = load_dataset("HuggingFaceH4/ultrachat_200k", split=split, cache_dir=cache_dir)
|
||||
if max_samples:
|
||||
ds = ds.select(range(min(max_samples, len(ds))))
|
||||
print(f"[SFT Data] {len(ds)} conversations loaded")
|
||||
|
||||
self.examples = []
|
||||
skipped = 0
|
||||
for i, row in enumerate(ds):
|
||||
messages = row["messages"]
|
||||
if len(messages) < 2:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
text = format_conversation(messages)
|
||||
all_ids = tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
# Need at least max_seq_len+1 for shift, but truncate if longer
|
||||
if len(all_ids) > max_seq_len + 1:
|
||||
all_ids = all_ids[:max_seq_len + 1]
|
||||
|
||||
if len(all_ids) < 10:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Shifted: input = all_ids[:-1], target = all_ids[1:]
|
||||
input_ids = all_ids[:-1]
|
||||
target_ids = all_ids[1:]
|
||||
|
||||
# Build mask: -100 for user turns, real token id for assistant turns
|
||||
labels = self._build_shifted_labels(input_ids, target_ids)
|
||||
self.examples.append((input_ids, labels))
|
||||
|
||||
if (i + 1) % 50000 == 0:
|
||||
print(f" Processed {i+1} conversations...")
|
||||
|
||||
print(f"[SFT Data] {len(self.examples)} examples ready, {skipped} skipped")
|
||||
|
||||
def _build_shifted_labels(self, input_ids, target_ids):
|
||||
"""
|
||||
Walk through the token sequence and track whether we're in a user turn
|
||||
or assistant turn. Only keep labels for assistant response content.
|
||||
|
||||
Masking strategy (applied to the SHIFTED target):
|
||||
- Everything before and including <|assistant|>\\n: masked
|
||||
- Assistant response content and <|end|>: TRAIN
|
||||
- <|user|> and user content until next <|assistant|>: masked
|
||||
"""
|
||||
labels = [-100] * len(target_ids)
|
||||
in_assistant = False
|
||||
|
||||
for i, tid in enumerate(input_ids):
|
||||
if tid == self.assistant_token_id:
|
||||
# Next token after <|assistant|> is \n, then content starts
|
||||
in_assistant = True
|
||||
continue
|
||||
|
||||
if tid == self.user_token_id:
|
||||
in_assistant = False
|
||||
continue
|
||||
|
||||
if in_assistant:
|
||||
labels[i] = target_ids[i]
|
||||
|
||||
# When we hit <|end|> in assistant mode, include it then switch off
|
||||
if tid == self.end_token_id and in_assistant:
|
||||
in_assistant = False
|
||||
|
||||
return labels
|
||||
|
||||
def __len__(self):
|
||||
return len(self.examples)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
input_ids, labels = self.examples[idx]
|
||||
return torch.tensor(input_ids, dtype=torch.long), torch.tensor(labels, dtype=torch.long)
|
||||
|
||||
|
||||
def sft_collate_fn(batch, pad_id=0):
|
||||
"""Pad sequences to the same length within a batch."""
|
||||
input_ids_list, labels_list = zip(*batch)
|
||||
max_len = max(ids.size(0) for ids in input_ids_list)
|
||||
|
||||
padded_inputs = []
|
||||
padded_labels = []
|
||||
for ids, lbl in zip(input_ids_list, labels_list):
|
||||
pad_len = max_len - ids.size(0)
|
||||
padded_inputs.append(torch.cat([ids, torch.full((pad_len,), pad_id, dtype=torch.long)]))
|
||||
padded_labels.append(torch.cat([lbl, torch.full((pad_len,), -100, dtype=torch.long)]))
|
||||
|
||||
return torch.stack(padded_inputs), torch.stack(padded_labels)
|
||||
|
||||
|
||||
def create_sft_dataloader(tokenizer, batch_size=4, max_seq_len=2048,
|
||||
cache_dir=None, max_samples=None, num_workers=4):
|
||||
dataset = SFTDataset(
|
||||
tokenizer=tokenizer,
|
||||
max_seq_len=max_seq_len,
|
||||
split="train_sft",
|
||||
cache_dir=cache_dir,
|
||||
max_samples=max_samples,
|
||||
)
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
num_workers=num_workers,
|
||||
pin_memory=True,
|
||||
collate_fn=lambda b: sft_collate_fn(b, pad_id=tokenizer.pad_token_id),
|
||||
), dataset
|
||||
163
training_code/model/transformer.py
Normal file
163
training_code/model/transformer.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
1B Parameter Decoder-Only Transformer — built from scratch.
|
||||
|
||||
Techniques:
|
||||
- RoPE (Rotary Position Embeddings)
|
||||
- Grouped Query Attention (GQA)
|
||||
- SwiGLU Feed-Forward
|
||||
- RMSNorm (pre-norm architecture)
|
||||
- Flash Attention 2 (via PyTorch SDPA)
|
||||
"""
|
||||
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .config import ModelConfig
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.weight = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
norm = x.float().pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
|
||||
return (x.float() * norm).type_as(x) * self.weight
|
||||
|
||||
|
||||
def precompute_rope_freqs(dim: int, max_seq_len: int, theta: float = 10000.0) -> torch.Tensor:
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
|
||||
t = torch.arange(max_seq_len, dtype=torch.float32)
|
||||
freqs = torch.outer(t, freqs)
|
||||
return torch.polar(torch.ones_like(freqs), freqs) # complex64
|
||||
|
||||
|
||||
def apply_rope(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
|
||||
B, S, H, D = xq.shape
|
||||
xq_c = torch.view_as_complex(xq.float().reshape(B, S, H, D // 2, 2))
|
||||
xk_c = torch.view_as_complex(xk.float().reshape(B, S, xk.shape[2], D // 2, 2))
|
||||
freqs = freqs_cis[:S].clone().unsqueeze(0).unsqueeze(2)
|
||||
xq_out = torch.view_as_real(xq_c * freqs).flatten(3)
|
||||
xk_out = torch.view_as_real(xk_c * freqs).flatten(3)
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class GroupedQueryAttention(nn.Module):
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.num_kv_heads = config.num_kv_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.num_groups = self.num_heads // self.num_kv_heads
|
||||
|
||||
self.wq = nn.Linear(config.hidden_dim, self.num_heads * self.head_dim, bias=False)
|
||||
self.wk = nn.Linear(config.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
|
||||
self.wv = nn.Linear(config.hidden_dim, self.num_kv_heads * self.head_dim, bias=False)
|
||||
self.wo = nn.Linear(self.num_heads * self.head_dim, config.hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
B, S, _ = x.shape
|
||||
|
||||
q = self.wq(x).view(B, S, self.num_heads, self.head_dim)
|
||||
k = self.wk(x).view(B, S, self.num_kv_heads, self.head_dim)
|
||||
v = self.wv(x).view(B, S, self.num_kv_heads, self.head_dim)
|
||||
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
|
||||
# Expand KV heads for GQA
|
||||
if self.num_groups > 1:
|
||||
k = k.unsqueeze(3).expand(B, S, self.num_kv_heads, self.num_groups, self.head_dim)
|
||||
k = k.reshape(B, S, self.num_heads, self.head_dim)
|
||||
v = v.unsqueeze(3).expand(B, S, self.num_kv_heads, self.num_groups, self.head_dim)
|
||||
v = v.reshape(B, S, self.num_heads, self.head_dim)
|
||||
|
||||
# (B, num_heads, S, head_dim) for SDPA
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
out = out.transpose(1, 2).contiguous().view(B, S, -1)
|
||||
return self.wo(out)
|
||||
|
||||
|
||||
class SwiGLUFFN(nn.Module):
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
self.w_gate = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
|
||||
self.w_up = nn.Linear(config.hidden_dim, config.intermediate_dim, bias=False)
|
||||
self.w_down = nn.Linear(config.intermediate_dim, config.hidden_dim, bias=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
self.attention_norm = RMSNorm(config.hidden_dim, eps=config.rms_norm_eps)
|
||||
self.attention = GroupedQueryAttention(config)
|
||||
self.ffn_norm = RMSNorm(config.hidden_dim, eps=config.rms_norm_eps)
|
||||
self.ffn = SwiGLUFFN(config)
|
||||
|
||||
def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
x = x + self.attention(self.attention_norm(x), freqs_cis)
|
||||
x = x + self.ffn(self.ffn_norm(x))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, config: ModelConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.tok_embeddings = nn.Embedding(config.vocab_size, config.hidden_dim)
|
||||
self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_layers)])
|
||||
self.norm = RMSNorm(config.hidden_dim, eps=config.rms_norm_eps)
|
||||
self.output = nn.Linear(config.hidden_dim, config.vocab_size, bias=False)
|
||||
|
||||
# Pre-compute RoPE frequencies
|
||||
self.register_buffer(
|
||||
"freqs_cis",
|
||||
precompute_rope_freqs(config.head_dim, config.max_seq_len * 2, config.rope_theta),
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
self._init_weights()
|
||||
|
||||
def _init_weights(self):
|
||||
"""Initialize with scaled normal, following GPT-NeoX / LLaMA conventions."""
|
||||
for module in self.modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
elif isinstance(module, nn.Embedding):
|
||||
nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
||||
|
||||
# Scale residual projections by 1/sqrt(2*num_layers)
|
||||
scale = (2 * self.config.num_layers) ** -0.5
|
||||
for layer in self.layers:
|
||||
nn.init.normal_(layer.attention.wo.weight, mean=0.0, std=0.02 * scale)
|
||||
nn.init.normal_(layer.ffn.w_down.weight, mean=0.0, std=0.02 * scale)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, targets: torch.Tensor = None):
|
||||
B, S = tokens.shape
|
||||
h = self.tok_embeddings(tokens)
|
||||
|
||||
freqs_cis = self.freqs_cis[:S]
|
||||
for layer in self.layers:
|
||||
h = layer(h, freqs_cis)
|
||||
h = self.norm(h)
|
||||
logits = self.output(h)
|
||||
|
||||
loss = None
|
||||
if targets is not None:
|
||||
loss = F.cross_entropy(
|
||||
logits.view(-1, logits.size(-1)),
|
||||
targets.view(-1),
|
||||
ignore_index=-100,
|
||||
)
|
||||
return logits, loss
|
||||
148
training_code/test_sft.py
Normal file
148
training_code/test_sft.py
Normal file
@@ -0,0 +1,148 @@
|
||||
"""Quick test of model quality with diverse prompts."""
|
||||
|
||||
import os, sys, time, torch
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer
|
||||
|
||||
DPO_CKPT = "/jfs/deepak-kumar/checkpoints_dpo/dpo_final.pt"
|
||||
SFT_CKPT = "/jfs/deepak-kumar/checkpoints_sft/sft_final.pt"
|
||||
CHECKPOINT = DPO_CKPT if os.path.exists(DPO_CKPT) else SFT_CKPT
|
||||
DEVICE = "cuda:0"
|
||||
|
||||
USER_START = "<|user|>\n"
|
||||
ASST_START = "<|assistant|>\n"
|
||||
TURN_END = "\n<|end|>\n"
|
||||
|
||||
TEST_PROMPTS = [
|
||||
"Hi! How are you?",
|
||||
"What is photosynthesis?",
|
||||
"Explain gravity to a 5-year-old.",
|
||||
"Write a short poem about the ocean.",
|
||||
"What are the three states of matter?",
|
||||
"How does a computer work?",
|
||||
"What is the capital of France and why is it famous?",
|
||||
"Give me 3 tips for learning a new language.",
|
||||
"What is machine learning in simple terms?",
|
||||
]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(model, tokenizer, prompt, max_new_tokens=256,
|
||||
temperature=0.7, top_k=50, top_p=0.9, repetition_penalty=1.15):
|
||||
input_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
input_ids = torch.tensor([input_ids], dtype=torch.long, device=DEVICE)
|
||||
generated = []
|
||||
eos_id = tokenizer.eos_token_id
|
||||
|
||||
end_token_ids = tokenizer.encode("<|end|>", add_special_tokens=False)
|
||||
end_id = end_token_ids[0] if end_token_ids else None
|
||||
user_token_ids = tokenizer.encode("<|user|>", add_special_tokens=False)
|
||||
user_id = user_token_ids[0] if user_token_ids else None
|
||||
|
||||
stop_ids = set()
|
||||
if eos_id is not None:
|
||||
stop_ids.add(eos_id)
|
||||
if end_id is not None:
|
||||
stop_ids.add(end_id)
|
||||
if user_id is not None:
|
||||
stop_ids.add(user_id)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
logits, _ = model(input_ids)
|
||||
|
||||
logits = logits[:, -1, :].float()
|
||||
|
||||
if repetition_penalty != 1.0 and generated:
|
||||
for tid in set(generated):
|
||||
if logits[0, tid] > 0:
|
||||
logits[0, tid] /= repetition_penalty
|
||||
else:
|
||||
logits[0, tid] *= repetition_penalty
|
||||
|
||||
logits = logits / max(temperature, 1e-5)
|
||||
|
||||
if top_k > 0:
|
||||
topk_vals, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
logits[logits < topk_vals[:, -1:]] = float('-inf')
|
||||
|
||||
if top_p < 1.0:
|
||||
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
|
||||
cumulative = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
remove = cumulative - torch.softmax(sorted_logits, dim=-1) > top_p
|
||||
sorted_logits[remove] = float('-inf')
|
||||
logits = sorted_logits.scatter(1, sorted_idx, sorted_logits)
|
||||
|
||||
probs = torch.softmax(logits, dim=-1)
|
||||
next_token = torch.multinomial(probs, 1)
|
||||
token_id = next_token.item()
|
||||
|
||||
if token_id in stop_ids:
|
||||
break
|
||||
|
||||
generated.append(token_id)
|
||||
input_ids = torch.cat([input_ids, next_token], dim=1)
|
||||
|
||||
if input_ids.size(1) > 2048:
|
||||
break
|
||||
|
||||
return tokenizer.decode(generated, skip_special_tokens=True)
|
||||
|
||||
|
||||
def main():
|
||||
ckpt_name = "DPO" if "dpo" in CHECKPOINT else "SFT"
|
||||
print("=" * 70)
|
||||
print(" " + ckpt_name + " MODEL TEST")
|
||||
print("=" * 70)
|
||||
|
||||
tokenizer = get_tokenizer()
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
config = ModelConfig()
|
||||
config.vocab_size = len(tokenizer)
|
||||
model = Transformer(config)
|
||||
|
||||
print("")
|
||||
print("Loading checkpoint: " + CHECKPOINT)
|
||||
ckpt = torch.load(CHECKPOINT, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(ckpt["model"])
|
||||
step = ckpt.get("step", "?")
|
||||
del ckpt
|
||||
|
||||
model = model.to(DEVICE).bfloat16().eval()
|
||||
print("Model loaded (" + ckpt_name + " step " + str(step) + ", vocab " + str(config.vocab_size) + ")")
|
||||
mem = torch.cuda.max_memory_allocated(DEVICE) / 1e9
|
||||
print("GPU memory: " + str(round(mem, 1)) + " GB")
|
||||
print("-" * 70)
|
||||
|
||||
for i, question in enumerate(TEST_PROMPTS, 1):
|
||||
prompt = USER_START + question + TURN_END + ASST_START
|
||||
|
||||
print("")
|
||||
print("[Test " + str(i) + "/" + str(len(TEST_PROMPTS)) + "]")
|
||||
print(" Q: " + question)
|
||||
|
||||
t0 = time.time()
|
||||
response = generate(model, tokenizer, prompt)
|
||||
dt = time.time() - t0
|
||||
tokens = len(tokenizer.encode(response, add_special_tokens=False))
|
||||
|
||||
response = response.split("<|end|>")[0].split("<|user|>")[0].strip()
|
||||
|
||||
print(" A: " + response)
|
||||
tps = int(tokens / max(dt, 0.01))
|
||||
print(" [" + str(tokens) + " tokens, " + str(round(dt, 1)) + "s, " + str(tps) + " tok/s]")
|
||||
print("-" * 70)
|
||||
|
||||
print("")
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
257
training_code/train.py
Normal file
257
training_code/train.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""
|
||||
Distributed training script for 1B parameter Transformer.
|
||||
|
||||
Launch: torchrun --nproc_per_node=8 train.py
|
||||
|
||||
Stack: PyTorch DDP + BF16 autocast + 8x H100 80GB
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import time
|
||||
import json
|
||||
import datetime
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig, TrainConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer, create_dataloader
|
||||
|
||||
|
||||
def get_wsd_lr(step, warmup_steps, total_steps, max_lr, min_lr):
|
||||
"""Warmup-Stable-Decay: linear warmup -> constant -> cosine decay (last 20%)."""
|
||||
stable_end = int(total_steps * 0.8)
|
||||
if step < warmup_steps:
|
||||
return max_lr * step / max(warmup_steps, 1)
|
||||
elif step < stable_end:
|
||||
return max_lr
|
||||
else:
|
||||
progress = (step - stable_end) / max(total_steps - stable_end, 1)
|
||||
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
|
||||
|
||||
|
||||
def find_latest_checkpoint(checkpoint_dir):
|
||||
"""Find the latest step_*.pt checkpoint in the directory."""
|
||||
import glob
|
||||
pattern = os.path.join(checkpoint_dir, "step_*.pt")
|
||||
files = glob.glob(pattern)
|
||||
if not files:
|
||||
return None, 0
|
||||
latest = max(files, key=lambda f: int(os.path.basename(f).replace("step_", "").replace(".pt", "")))
|
||||
step = int(os.path.basename(latest).replace("step_", "").replace(".pt", ""))
|
||||
return latest, step
|
||||
|
||||
|
||||
def main():
|
||||
dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
model_config = ModelConfig()
|
||||
train_config = TrainConfig()
|
||||
|
||||
eff_batch = train_config.batch_size_per_gpu * world_size * train_config.gradient_accumulation_steps
|
||||
tokens_per_step = eff_batch * model_config.max_seq_len
|
||||
total_steps = train_config.total_tokens // tokens_per_step
|
||||
|
||||
if rank == 0:
|
||||
os.makedirs(train_config.log_dir, exist_ok=True)
|
||||
os.makedirs(train_config.checkpoint_dir, exist_ok=True)
|
||||
print("=" * 70)
|
||||
print(f" TRAINING 1B TRANSFORMER FROM SCRATCH")
|
||||
print(f" Arch: {model_config.num_layers}L / {model_config.hidden_dim}D / "
|
||||
f"{model_config.num_attention_heads}H / GQA-{model_config.num_kv_heads}KV / "
|
||||
f"SwiGLU-{model_config.intermediate_dim}")
|
||||
print(f" Seq: {model_config.max_seq_len} | Vocab: {model_config.vocab_size}")
|
||||
print(f" GPUs: {world_size}x H100 80GB | Backend: DDP + BF16 autocast")
|
||||
print(f" Batch: {eff_batch} seqs = {tokens_per_step:,} tok/step")
|
||||
print(f" Steps: {total_steps:,} | Target: {train_config.total_tokens:,} tokens")
|
||||
print("=" * 70)
|
||||
|
||||
# Tokenizer
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
# Model
|
||||
torch.manual_seed(train_config.seed)
|
||||
model = Transformer(model_config).to(device)
|
||||
|
||||
if rank == 0:
|
||||
n = sum(p.numel() for p in model.parameters())
|
||||
print(f"[Init] Params: {n:,} ({n/1e9:.3f}B)")
|
||||
|
||||
model = DDP(model, device_ids=[local_rank])
|
||||
|
||||
# Optimizer
|
||||
decay_params = [p for n, p in model.named_parameters() if p.dim() >= 2 and p.requires_grad]
|
||||
nodecay_params = [p for n, p in model.named_parameters() if p.dim() < 2 and p.requires_grad]
|
||||
optimizer = torch.optim.AdamW([
|
||||
{"params": decay_params, "weight_decay": train_config.weight_decay},
|
||||
{"params": nodecay_params, "weight_decay": 0.0},
|
||||
], lr=train_config.learning_rate, betas=(train_config.beta1, train_config.beta2), fused=True)
|
||||
|
||||
if rank == 0:
|
||||
dp = sum(p.numel() for p in decay_params)
|
||||
ndp = sum(p.numel() for p in nodecay_params)
|
||||
print(f"[Init] Optimizer: {dp:,} decay + {ndp:,} no-decay params")
|
||||
|
||||
# Resume from checkpoint
|
||||
resume_step = 0
|
||||
ckpt_path, ckpt_step = find_latest_checkpoint(train_config.checkpoint_dir)
|
||||
if ckpt_path is not None:
|
||||
if rank == 0:
|
||||
print(f"[Resume] Loading checkpoint: {ckpt_path} (step {ckpt_step})")
|
||||
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
|
||||
model.module.load_state_dict(ckpt["model"])
|
||||
optimizer.load_state_dict(ckpt["optimizer"])
|
||||
resume_step = ckpt["step"]
|
||||
if rank == 0:
|
||||
print(f"[Resume] Restored model + optimizer at step {resume_step}, "
|
||||
f"loss was {ckpt.get('loss', 'N/A')}")
|
||||
del ckpt
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
if rank == 0:
|
||||
print("[Init] No checkpoint found, starting from scratch")
|
||||
|
||||
# Data — use (seed + resume_step) so resumed runs see different shuffled data
|
||||
effective_seed = train_config.seed + resume_step
|
||||
dataloader = create_dataloader(tokenizer, train_config, rank=rank, world_size=world_size,
|
||||
seed_override=effective_seed)
|
||||
data_iter = iter(dataloader)
|
||||
|
||||
if rank == 0:
|
||||
print(f"[Init] Dataloader ready (streaming FineWeb-Edu 10BT)")
|
||||
print(f"[Schedule] WSD: warmup {train_config.warmup_steps} -> "
|
||||
f"stable {int(total_steps*0.8)} -> decay {total_steps}")
|
||||
if resume_step > 0:
|
||||
remaining = total_steps - resume_step
|
||||
print(f"[Resume] Continuing from step {resume_step}, {remaining:,} steps remaining")
|
||||
print("-" * 70)
|
||||
sys.stdout.flush()
|
||||
|
||||
# ===== TRAINING LOOP =====
|
||||
model.train()
|
||||
global_step = resume_step
|
||||
running_loss = 0.0
|
||||
best_loss = float("inf")
|
||||
tokens_done = resume_step * tokens_per_step
|
||||
t0 = time.time()
|
||||
step_t0 = time.time()
|
||||
|
||||
log_file = open(os.path.join(train_config.log_dir, "train_log.jsonl"), "a") if rank == 0 else None
|
||||
|
||||
while global_step < total_steps:
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
micro_loss = 0.0
|
||||
|
||||
for micro in range(train_config.gradient_accumulation_steps):
|
||||
try:
|
||||
input_ids, labels = next(data_iter)
|
||||
except StopIteration:
|
||||
data_iter = iter(dataloader)
|
||||
input_ids, labels = next(data_iter)
|
||||
|
||||
input_ids = input_ids.to(device, non_blocking=True)
|
||||
labels = labels.to(device, non_blocking=True)
|
||||
|
||||
# BF16 autocast — no scaler needed (BF16 has enough dynamic range)
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
_, loss = model(input_ids, labels)
|
||||
loss = loss / train_config.gradient_accumulation_steps
|
||||
|
||||
loss.backward()
|
||||
micro_loss += loss.item()
|
||||
|
||||
# Gradient clipping
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.grad_clip)
|
||||
|
||||
# LR schedule
|
||||
lr = get_wsd_lr(global_step, train_config.warmup_steps, total_steps,
|
||||
train_config.learning_rate, train_config.min_lr)
|
||||
for pg in optimizer.param_groups:
|
||||
pg["lr"] = lr
|
||||
|
||||
optimizer.step()
|
||||
global_step += 1
|
||||
running_loss += micro_loss
|
||||
tokens_done += tokens_per_step
|
||||
|
||||
# Log
|
||||
if global_step % train_config.log_interval == 0:
|
||||
dt = time.time() - step_t0
|
||||
tps = (train_config.log_interval * tokens_per_step) / max(dt, 1e-9)
|
||||
avg = running_loss / train_config.log_interval
|
||||
elapsed = time.time() - t0
|
||||
pct = 100.0 * global_step / total_steps
|
||||
eta = (elapsed / max(global_step, 1)) * (total_steps - global_step)
|
||||
|
||||
if rank == 0:
|
||||
gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
|
||||
print(
|
||||
f"[Step {global_step:>6d}/{total_steps}] "
|
||||
f"loss={avg:.4f} | lr={lr:.2e} | "
|
||||
f"tok/s={tps:,.0f} | GPU={gpu_mem:.1f}GB | "
|
||||
f"{pct:.1f}% | ETA={eta/3600:.1f}h",
|
||||
flush=True,
|
||||
)
|
||||
if log_file:
|
||||
log_file.write(json.dumps({
|
||||
"step": global_step, "loss": round(avg, 4), "lr": lr,
|
||||
"tps": round(tps), "tokens": tokens_done,
|
||||
"gpu_gb": round(gpu_mem, 1), "elapsed_s": round(elapsed, 1),
|
||||
}) + "\n")
|
||||
log_file.flush()
|
||||
|
||||
if avg < best_loss:
|
||||
best_loss = avg
|
||||
running_loss = 0.0
|
||||
step_t0 = time.time()
|
||||
|
||||
# Checkpoint
|
||||
if global_step % train_config.save_interval == 0:
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
ckpt_path = os.path.join(train_config.checkpoint_dir, f"step_{global_step}.pt")
|
||||
torch.save({
|
||||
"step": global_step,
|
||||
"model": model.module.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"loss": avg if global_step % train_config.log_interval == 0 else micro_loss,
|
||||
"config": {"model": model_config.__dict__, "train": train_config.__dict__},
|
||||
}, ckpt_path)
|
||||
print(f" >> Checkpoint: {ckpt_path}", flush=True)
|
||||
dist.barrier()
|
||||
|
||||
# Final
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
final_path = os.path.join(train_config.checkpoint_dir, "final.pt")
|
||||
torch.save({
|
||||
"step": global_step,
|
||||
"model": model.module.state_dict(),
|
||||
"config": {"model": model_config.__dict__, "train": train_config.__dict__},
|
||||
}, final_path)
|
||||
total_time = time.time() - t0
|
||||
print("=" * 70)
|
||||
print(f" TRAINING COMPLETE")
|
||||
print(f" Steps: {global_step:,} | Tokens: {tokens_done:,}")
|
||||
print(f" Time: {total_time/3600:.2f}h | Throughput: {tokens_done/total_time:,.0f} tok/s")
|
||||
print(f" Best loss: {best_loss:.4f}")
|
||||
print(f" Final model: {final_path}")
|
||||
print("=" * 70)
|
||||
if log_file:
|
||||
log_file.close()
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
327
training_code/train_dpo.py
Normal file
327
training_code/train_dpo.py
Normal file
@@ -0,0 +1,327 @@
|
||||
"""
|
||||
DPO (Direct Preference Optimization) training for the 1B Transformer.
|
||||
|
||||
Takes the SFT model and aligns it with human preferences using
|
||||
UltraFeedback preference pairs.
|
||||
|
||||
DPO Loss:
|
||||
L = -log sigma(beta * (log pi(yw|x)/pi_ref(yw|x) - log pi(yl|x)/pi_ref(yl|x)))
|
||||
|
||||
Launch: torchrun --nproc_per_node=8 train_dpo.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import time
|
||||
import json
|
||||
import datetime
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer
|
||||
from model.dpo_data import DPODataset, dpo_collate_fn
|
||||
|
||||
|
||||
# === Config ===
|
||||
SFT_CHECKPOINT = "/jfs/deepak-kumar/checkpoints_sft/sft_final.pt"
|
||||
DPO_CHECKPOINT_DIR = "/jfs/deepak-kumar/checkpoints_dpo"
|
||||
LOG_DIR = "/home/jovyan/training/logs"
|
||||
DATA_CACHE = "/jfs/deepak-kumar/data"
|
||||
|
||||
NUM_EPOCHS = 1
|
||||
BATCH_SIZE_PER_GPU = 2
|
||||
GRADIENT_ACCUMULATION = 4 # effective batch = 2 * 8 * 4 = 64
|
||||
MAX_SEQ_LEN = 1024
|
||||
LEARNING_RATE = 5e-7 # very low LR for DPO
|
||||
MIN_LR = 1e-7
|
||||
WARMUP_STEPS = 100
|
||||
WEIGHT_DECAY = 0.01
|
||||
GRAD_CLIP = 1.0
|
||||
BETA = 0.1 # DPO temperature
|
||||
LOG_INTERVAL = 10
|
||||
SAVE_INTERVAL = 200
|
||||
|
||||
|
||||
def get_cosine_lr(step, warmup_steps, total_steps, max_lr, min_lr):
|
||||
if step < warmup_steps:
|
||||
return max_lr * step / max(warmup_steps, 1)
|
||||
progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
|
||||
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
|
||||
|
||||
|
||||
def get_per_token_logps(model, input_ids, prompt_lens):
|
||||
"""
|
||||
Compute sum of log probabilities for response tokens only.
|
||||
input_ids: [B, S] full sequence (prompt + response)
|
||||
prompt_lens: [B] where response starts
|
||||
Returns: [B] sum of log probs over response tokens
|
||||
"""
|
||||
# Clone input to avoid inplace issues with shared RoPE buffers
|
||||
inp = input_ids[:, :-1].contiguous()
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
logits, _ = model(inp)
|
||||
|
||||
labels = input_ids[:, 1:].contiguous()
|
||||
log_probs = F.log_softmax(logits.float(), dim=-1)
|
||||
token_logps = log_probs.gather(2, labels.unsqueeze(2)).squeeze(2)
|
||||
|
||||
B, S = token_logps.shape
|
||||
mask = torch.zeros_like(token_logps)
|
||||
for b in range(B):
|
||||
pl = prompt_lens[b].item()
|
||||
response_start = max(0, pl - 1)
|
||||
seq_len = (labels[b] != 0).sum().item()
|
||||
mask[b, response_start:seq_len] = 1.0
|
||||
|
||||
return (token_logps * mask).sum(dim=1)
|
||||
|
||||
|
||||
def dpo_loss(policy_chosen_logps, policy_rejected_logps,
|
||||
ref_chosen_logps, ref_rejected_logps, beta=0.1):
|
||||
"""Compute DPO loss and metrics."""
|
||||
chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
|
||||
rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
|
||||
|
||||
logits = chosen_rewards - rejected_rewards
|
||||
loss = -F.logsigmoid(logits).mean()
|
||||
|
||||
with torch.no_grad():
|
||||
chosen_better = (chosen_rewards > rejected_rewards).float().mean()
|
||||
reward_margin = (chosen_rewards - rejected_rewards).mean()
|
||||
|
||||
return loss, chosen_better.item(), reward_margin.item()
|
||||
|
||||
|
||||
def main():
|
||||
dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
if rank == 0:
|
||||
os.makedirs(DPO_CHECKPOINT_DIR, exist_ok=True)
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
print("=" * 70)
|
||||
print(" DPO: PREFERENCE ALIGNMENT FOR 1B TRANSFORMER")
|
||||
print("=" * 70)
|
||||
|
||||
tokenizer = get_tokenizer()
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
model_config = ModelConfig()
|
||||
model_config.vocab_size = len(tokenizer)
|
||||
|
||||
if rank == 0:
|
||||
print(f"[Init] Loading SFT model from {SFT_CHECKPOINT}")
|
||||
|
||||
# Policy model (trainable)
|
||||
policy = Transformer(model_config)
|
||||
ckpt = torch.load(SFT_CHECKPOINT, map_location="cpu", weights_only=False)
|
||||
policy.load_state_dict(ckpt["model"])
|
||||
sft_step = ckpt.get("step", 0)
|
||||
if rank == 0:
|
||||
print(f"[Init] SFT model loaded (step {sft_step})")
|
||||
|
||||
# Reference model (frozen copy)
|
||||
ref_model = Transformer(model_config)
|
||||
ref_model.load_state_dict(ckpt["model"])
|
||||
del ckpt
|
||||
|
||||
policy = policy.to(device)
|
||||
ref_model = ref_model.to(device).bfloat16()
|
||||
ref_model.eval()
|
||||
for p in ref_model.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
policy = DDP(policy, device_ids=[local_rank])
|
||||
|
||||
if rank == 0:
|
||||
n = sum(p.numel() for p in policy.parameters())
|
||||
print(f"[Init] Params: {n:,} | GPUs: {world_size}x H100")
|
||||
print(f"[Init] Beta: {BETA} | LR: {LEARNING_RATE}")
|
||||
|
||||
# Dataset
|
||||
dataset = DPODataset(
|
||||
tokenizer=tokenizer,
|
||||
max_seq_len=MAX_SEQ_LEN,
|
||||
split="train",
|
||||
cache_dir=DATA_CACHE,
|
||||
)
|
||||
|
||||
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=BATCH_SIZE_PER_GPU,
|
||||
sampler=sampler,
|
||||
num_workers=4,
|
||||
pin_memory=True,
|
||||
collate_fn=lambda b: dpo_collate_fn(b, pad_id=tokenizer.pad_token_id),
|
||||
)
|
||||
|
||||
steps_per_epoch = len(dataloader) // GRADIENT_ACCUMULATION
|
||||
total_steps = steps_per_epoch * NUM_EPOCHS
|
||||
|
||||
if rank == 0:
|
||||
eff_batch = BATCH_SIZE_PER_GPU * world_size * GRADIENT_ACCUMULATION
|
||||
print(f"[Init] Dataset: {len(dataset):,} preference pairs")
|
||||
print(f"[Init] Effective batch: {eff_batch} | Steps/epoch: {steps_per_epoch}")
|
||||
print(f"[Init] Total steps: {total_steps}")
|
||||
print("-" * 70)
|
||||
|
||||
decay_params = [p for n, p in policy.named_parameters() if p.dim() >= 2 and p.requires_grad]
|
||||
nodecay_params = [p for n, p in policy.named_parameters() if p.dim() < 2 and p.requires_grad]
|
||||
optimizer = torch.optim.AdamW([
|
||||
{"params": decay_params, "weight_decay": WEIGHT_DECAY},
|
||||
{"params": nodecay_params, "weight_decay": 0.0},
|
||||
], lr=LEARNING_RATE, betas=(0.9, 0.95), fused=True)
|
||||
|
||||
policy.train()
|
||||
global_step = 0
|
||||
running_loss = 0.0
|
||||
running_acc = 0.0
|
||||
running_margin = 0.0
|
||||
t0 = time.time()
|
||||
|
||||
log_file = open(os.path.join(LOG_DIR, "dpo_log.jsonl"), "w") if rank == 0 else None
|
||||
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
sampler.set_epoch(epoch)
|
||||
data_iter = iter(dataloader)
|
||||
|
||||
if rank == 0:
|
||||
print(f"\n[Epoch {epoch + 1}/{NUM_EPOCHS}]")
|
||||
|
||||
while True:
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
batch_loss = 0.0
|
||||
batch_acc = 0.0
|
||||
batch_margin = 0.0
|
||||
valid_micros = 0
|
||||
|
||||
for _ in range(GRADIENT_ACCUMULATION):
|
||||
try:
|
||||
batch = next(data_iter)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
chosen_ids = batch["chosen_ids"].to(device, non_blocking=True)
|
||||
rejected_ids = batch["rejected_ids"].to(device, non_blocking=True)
|
||||
prompt_lens = batch["prompt_lens"].to(device, non_blocking=True)
|
||||
|
||||
policy_chosen_logps = get_per_token_logps(policy, chosen_ids, prompt_lens)
|
||||
policy_rejected_logps = get_per_token_logps(policy, rejected_ids, prompt_lens)
|
||||
|
||||
with torch.no_grad():
|
||||
ref_chosen_logps = get_per_token_logps(ref_model, chosen_ids, prompt_lens)
|
||||
ref_rejected_logps = get_per_token_logps(ref_model, rejected_ids, prompt_lens)
|
||||
|
||||
loss, acc, margin = dpo_loss(
|
||||
policy_chosen_logps, policy_rejected_logps,
|
||||
ref_chosen_logps, ref_rejected_logps,
|
||||
beta=BETA,
|
||||
)
|
||||
loss = loss / GRADIENT_ACCUMULATION
|
||||
loss.backward()
|
||||
|
||||
batch_loss += loss.item()
|
||||
batch_acc += acc
|
||||
batch_margin += margin
|
||||
valid_micros += 1
|
||||
|
||||
if valid_micros == 0:
|
||||
break
|
||||
|
||||
torch.nn.utils.clip_grad_norm_(policy.parameters(), GRAD_CLIP)
|
||||
|
||||
lr = get_cosine_lr(global_step, WARMUP_STEPS, total_steps, LEARNING_RATE, MIN_LR)
|
||||
for pg in optimizer.param_groups:
|
||||
pg["lr"] = lr
|
||||
|
||||
optimizer.step()
|
||||
global_step += 1
|
||||
running_loss += batch_loss
|
||||
running_acc += batch_acc / valid_micros
|
||||
running_margin += batch_margin / valid_micros
|
||||
|
||||
if global_step % LOG_INTERVAL == 0:
|
||||
avg_loss = running_loss / LOG_INTERVAL
|
||||
avg_acc = running_acc / LOG_INTERVAL
|
||||
avg_margin = running_margin / LOG_INTERVAL
|
||||
elapsed = time.time() - t0
|
||||
pct = 100.0 * global_step / total_steps
|
||||
eta = (elapsed / max(global_step, 1)) * (total_steps - global_step)
|
||||
|
||||
if rank == 0:
|
||||
gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
|
||||
print(
|
||||
f" [Step {global_step:>5d}/{total_steps}] "
|
||||
f"loss={avg_loss:.4f} | acc={avg_acc:.1%} | "
|
||||
f"margin={avg_margin:.3f} | lr={lr:.2e} | "
|
||||
f"GPU={gpu_mem:.1f}GB | {pct:.1f}% | ETA={eta/60:.0f}m",
|
||||
flush=True,
|
||||
)
|
||||
if log_file:
|
||||
log_file.write(json.dumps({
|
||||
"step": global_step, "loss": round(avg_loss, 4),
|
||||
"accuracy": round(avg_acc, 4),
|
||||
"reward_margin": round(avg_margin, 4),
|
||||
"lr": lr, "elapsed_s": round(elapsed, 1),
|
||||
}) + "\n")
|
||||
log_file.flush()
|
||||
|
||||
running_loss = 0.0
|
||||
running_acc = 0.0
|
||||
running_margin = 0.0
|
||||
|
||||
if global_step % SAVE_INTERVAL == 0:
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
path = os.path.join(DPO_CHECKPOINT_DIR, f"dpo_step_{global_step}.pt")
|
||||
torch.save({
|
||||
"step": global_step,
|
||||
"model": policy.module.state_dict(),
|
||||
"config": model_config.__dict__,
|
||||
"vocab_size": model_config.vocab_size,
|
||||
}, path)
|
||||
print(f" >> Checkpoint: {path}", flush=True)
|
||||
dist.barrier()
|
||||
|
||||
# Final save
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
final_path = os.path.join(DPO_CHECKPOINT_DIR, "dpo_final.pt")
|
||||
torch.save({
|
||||
"step": global_step,
|
||||
"model": policy.module.state_dict(),
|
||||
"config": model_config.__dict__,
|
||||
"vocab_size": model_config.vocab_size,
|
||||
}, final_path)
|
||||
total_time = time.time() - t0
|
||||
print("=" * 70)
|
||||
print(f" DPO COMPLETE")
|
||||
print(f" Steps: {global_step:,} | Epochs: {NUM_EPOCHS}")
|
||||
print(f" Time: {total_time/60:.1f} minutes")
|
||||
print(f" Final model: {final_path}")
|
||||
print("=" * 70)
|
||||
if log_file:
|
||||
log_file.close()
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
272
training_code/train_sft.py
Normal file
272
training_code/train_sft.py
Normal file
@@ -0,0 +1,272 @@
|
||||
"""
|
||||
SFT (Supervised Fine-Tuning) script for the 1B Transformer.
|
||||
|
||||
Takes the pretrained base model and fine-tunes it on instruction-response
|
||||
conversations from UltraChat 200K.
|
||||
|
||||
Launch: torchrun --nproc_per_node=8 train_sft.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import math
|
||||
import time
|
||||
import json
|
||||
import datetime
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||
from model.config import ModelConfig
|
||||
from model.transformer import Transformer
|
||||
from model.data import get_tokenizer
|
||||
from model.sft_data import SFTDataset, sft_collate_fn
|
||||
|
||||
|
||||
# === Config ===
|
||||
BASE_CHECKPOINT = "/jfs/deepak-kumar/checkpoints/step_19000.pt"
|
||||
SFT_CHECKPOINT_DIR = "/jfs/deepak-kumar/checkpoints_sft"
|
||||
LOG_DIR = "/home/jovyan/training/logs"
|
||||
DATA_CACHE = "/jfs/deepak-kumar/data"
|
||||
|
||||
NUM_EPOCHS = 2
|
||||
BATCH_SIZE_PER_GPU = 4
|
||||
GRADIENT_ACCUMULATION = 4 # effective batch = 4 * 8 * 4 = 128
|
||||
MAX_SEQ_LEN = 2048
|
||||
LEARNING_RATE = 2e-5 # much lower than pretraining — we're fine-tuning
|
||||
MIN_LR = 2e-6
|
||||
WARMUP_STEPS = 200
|
||||
WEIGHT_DECAY = 0.01
|
||||
GRAD_CLIP = 1.0
|
||||
LOG_INTERVAL = 10
|
||||
SAVE_INTERVAL = 500
|
||||
|
||||
|
||||
def get_cosine_lr(step, warmup_steps, total_steps, max_lr, min_lr):
|
||||
if step < warmup_steps:
|
||||
return max_lr * step / max(warmup_steps, 1)
|
||||
progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
|
||||
return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
|
||||
|
||||
|
||||
def main():
|
||||
dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
if rank == 0:
|
||||
os.makedirs(SFT_CHECKPOINT_DIR, exist_ok=True)
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
print("=" * 70)
|
||||
print(" SFT: INSTRUCTION FINE-TUNING 1B TRANSFORMER")
|
||||
print("=" * 70)
|
||||
|
||||
# Tokenizer
|
||||
tokenizer = get_tokenizer()
|
||||
|
||||
# Load base model
|
||||
model_config = ModelConfig()
|
||||
torch.manual_seed(42)
|
||||
model = Transformer(model_config)
|
||||
|
||||
if rank == 0:
|
||||
print(f"[Init] Loading base model from {BASE_CHECKPOINT}")
|
||||
ckpt = torch.load(BASE_CHECKPOINT, map_location="cpu", weights_only=False)
|
||||
model.load_state_dict(ckpt["model"])
|
||||
base_step = ckpt.get("step", 0)
|
||||
base_loss = ckpt.get("loss", "?")
|
||||
if rank == 0:
|
||||
print(f"[Init] Base model: step={base_step}, pretrain_loss={base_loss}")
|
||||
del ckpt
|
||||
|
||||
# Add chat tokens to embedding — expand vocab if needed
|
||||
special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
|
||||
vocab = tokenizer.get_vocab()
|
||||
new_tokens = [t for t in special_tokens if t not in vocab]
|
||||
if new_tokens:
|
||||
tokenizer.add_tokens(new_tokens, special_tokens=True)
|
||||
|
||||
new_vocab_size = len(tokenizer)
|
||||
if new_vocab_size > model_config.vocab_size:
|
||||
if rank == 0:
|
||||
print(f"[Init] Expanding vocab: {model_config.vocab_size} -> {new_vocab_size}")
|
||||
|
||||
old_emb_weight = model.tok_embeddings.weight.data
|
||||
model.tok_embeddings = torch.nn.Embedding(new_vocab_size, model_config.hidden_dim)
|
||||
model.tok_embeddings.weight.data[:model_config.vocab_size] = old_emb_weight
|
||||
# Init new token embeddings as mean of existing (better than random)
|
||||
mean_emb = old_emb_weight.mean(dim=0)
|
||||
for i in range(model_config.vocab_size, new_vocab_size):
|
||||
model.tok_embeddings.weight.data[i] = mean_emb
|
||||
|
||||
old_output_weight = model.output.weight.data
|
||||
model.output = torch.nn.Linear(model_config.hidden_dim, new_vocab_size, bias=False)
|
||||
model.output.weight.data[:model_config.vocab_size] = old_output_weight
|
||||
|
||||
model.config.vocab_size = new_vocab_size
|
||||
|
||||
model = model.to(device)
|
||||
model = DDP(model, device_ids=[local_rank])
|
||||
|
||||
if rank == 0:
|
||||
n = sum(p.numel() for p in model.parameters())
|
||||
print(f"[Init] Params: {n:,} | GPUs: {world_size}x H100")
|
||||
|
||||
# Dataset (only load on each process)
|
||||
dataset = SFTDataset(
|
||||
tokenizer=tokenizer,
|
||||
max_seq_len=MAX_SEQ_LEN,
|
||||
split="train_sft",
|
||||
cache_dir=DATA_CACHE,
|
||||
)
|
||||
|
||||
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=BATCH_SIZE_PER_GPU,
|
||||
sampler=sampler,
|
||||
num_workers=4,
|
||||
pin_memory=True,
|
||||
collate_fn=lambda b: sft_collate_fn(b, pad_id=tokenizer.pad_token_id),
|
||||
)
|
||||
|
||||
steps_per_epoch = len(dataloader) // GRADIENT_ACCUMULATION
|
||||
total_steps = steps_per_epoch * NUM_EPOCHS
|
||||
|
||||
if rank == 0:
|
||||
eff_batch = BATCH_SIZE_PER_GPU * world_size * GRADIENT_ACCUMULATION
|
||||
print(f"[Init] Dataset: {len(dataset):,} examples")
|
||||
print(f"[Init] Effective batch: {eff_batch} | Steps/epoch: {steps_per_epoch}")
|
||||
print(f"[Init] Total steps: {total_steps} | Epochs: {NUM_EPOCHS}")
|
||||
print(f"[Init] LR: {LEARNING_RATE} → {MIN_LR} (cosine)")
|
||||
print("-" * 70)
|
||||
|
||||
# Optimizer — lower LR for fine-tuning
|
||||
decay_params = [p for n, p in model.named_parameters() if p.dim() >= 2 and p.requires_grad]
|
||||
nodecay_params = [p for n, p in model.named_parameters() if p.dim() < 2 and p.requires_grad]
|
||||
optimizer = torch.optim.AdamW([
|
||||
{"params": decay_params, "weight_decay": WEIGHT_DECAY},
|
||||
{"params": nodecay_params, "weight_decay": 0.0},
|
||||
], lr=LEARNING_RATE, betas=(0.9, 0.95), fused=True)
|
||||
|
||||
# Training
|
||||
model.train()
|
||||
global_step = 0
|
||||
running_loss = 0.0
|
||||
t0 = time.time()
|
||||
step_t0 = time.time()
|
||||
|
||||
log_file = open(os.path.join(LOG_DIR, "sft_log.jsonl"), "w") if rank == 0 else None
|
||||
|
||||
for epoch in range(NUM_EPOCHS):
|
||||
sampler.set_epoch(epoch)
|
||||
data_iter = iter(dataloader)
|
||||
micro_step = 0
|
||||
|
||||
if rank == 0:
|
||||
print(f"\n[Epoch {epoch + 1}/{NUM_EPOCHS}]")
|
||||
|
||||
while True:
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
batch_loss = 0.0
|
||||
|
||||
for _ in range(GRADIENT_ACCUMULATION):
|
||||
try:
|
||||
input_ids, labels = next(data_iter)
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
input_ids = input_ids.to(device, non_blocking=True)
|
||||
labels = labels.to(device, non_blocking=True)
|
||||
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
_, loss = model(input_ids, labels)
|
||||
loss = loss / GRADIENT_ACCUMULATION
|
||||
|
||||
loss.backward()
|
||||
batch_loss += loss.item()
|
||||
micro_step += 1
|
||||
|
||||
if batch_loss == 0:
|
||||
break
|
||||
|
||||
torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
|
||||
|
||||
lr = get_cosine_lr(global_step, WARMUP_STEPS, total_steps, LEARNING_RATE, MIN_LR)
|
||||
for pg in optimizer.param_groups:
|
||||
pg["lr"] = lr
|
||||
|
||||
optimizer.step()
|
||||
global_step += 1
|
||||
running_loss += batch_loss
|
||||
|
||||
if global_step % LOG_INTERVAL == 0:
|
||||
dt = time.time() - step_t0
|
||||
avg = running_loss / LOG_INTERVAL
|
||||
elapsed = time.time() - t0
|
||||
pct = 100.0 * global_step / total_steps
|
||||
|
||||
if rank == 0:
|
||||
gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
|
||||
eta = (elapsed / max(global_step, 1)) * (total_steps - global_step)
|
||||
print(
|
||||
f" [Step {global_step:>5d}/{total_steps}] "
|
||||
f"loss={avg:.4f} | lr={lr:.2e} | "
|
||||
f"GPU={gpu_mem:.1f}GB | {pct:.1f}% | ETA={eta/60:.0f}m",
|
||||
flush=True,
|
||||
)
|
||||
if log_file:
|
||||
log_file.write(json.dumps({
|
||||
"step": global_step, "epoch": epoch + 1,
|
||||
"loss": round(avg, 4), "lr": lr,
|
||||
"elapsed_s": round(elapsed, 1),
|
||||
}) + "\n")
|
||||
log_file.flush()
|
||||
|
||||
running_loss = 0.0
|
||||
step_t0 = time.time()
|
||||
|
||||
if global_step % SAVE_INTERVAL == 0:
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
path = os.path.join(SFT_CHECKPOINT_DIR, f"sft_step_{global_step}.pt")
|
||||
torch.save({
|
||||
"step": global_step,
|
||||
"model": model.module.state_dict(),
|
||||
"config": model_config.__dict__,
|
||||
"vocab_size": new_vocab_size,
|
||||
}, path)
|
||||
print(f" >> Checkpoint: {path}", flush=True)
|
||||
dist.barrier()
|
||||
|
||||
# Final save
|
||||
dist.barrier()
|
||||
if rank == 0:
|
||||
final_path = os.path.join(SFT_CHECKPOINT_DIR, "sft_final.pt")
|
||||
torch.save({
|
||||
"step": global_step,
|
||||
"model": model.module.state_dict(),
|
||||
"config": model_config.__dict__,
|
||||
"vocab_size": new_vocab_size,
|
||||
}, final_path)
|
||||
total_time = time.time() - t0
|
||||
print("=" * 70)
|
||||
print(f" SFT COMPLETE")
|
||||
print(f" Steps: {global_step:,} | Epochs: {NUM_EPOCHS}")
|
||||
print(f" Time: {total_time/60:.1f} minutes")
|
||||
print(f" Final model: {final_path}")
|
||||
print("=" * 70)
|
||||
if log_file:
|
||||
log_file.close()
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
95
training_logs/dpo_log.jsonl
Normal file
95
training_logs/dpo_log.jsonl
Normal file
@@ -0,0 +1,95 @@
|
||||
{"step": 10, "loss": 0.7798, "accuracy": 0.4375, "reward_margin": 0.0867, "lr": 4.5e-08, "elapsed_s": 11.5}
|
||||
{"step": 20, "loss": 0.7714, "accuracy": 0.375, "reward_margin": -0.0647, "lr": 9.499999999999999e-08, "elapsed_s": 19.4}
|
||||
{"step": 30, "loss": 0.7522, "accuracy": 0.4125, "reward_margin": -0.0699, "lr": 1.45e-07, "elapsed_s": 27.2}
|
||||
{"step": 40, "loss": 0.6921, "accuracy": 0.525, "reward_margin": 0.0268, "lr": 1.9499999999999999e-07, "elapsed_s": 35.0}
|
||||
{"step": 50, "loss": 0.6667, "accuracy": 0.525, "reward_margin": 0.1128, "lr": 2.45e-07, "elapsed_s": 42.7}
|
||||
{"step": 60, "loss": 0.6777, "accuracy": 0.4625, "reward_margin": 0.0764, "lr": 2.95e-07, "elapsed_s": 50.5}
|
||||
{"step": 70, "loss": 0.6515, "accuracy": 0.55, "reward_margin": 0.1272, "lr": 3.45e-07, "elapsed_s": 58.3}
|
||||
{"step": 80, "loss": 0.6387, "accuracy": 0.6, "reward_margin": 0.1564, "lr": 3.95e-07, "elapsed_s": 66.2}
|
||||
{"step": 90, "loss": 0.586, "accuracy": 0.675, "reward_margin": 0.2776, "lr": 4.45e-07, "elapsed_s": 74.0}
|
||||
{"step": 100, "loss": 0.5986, "accuracy": 0.675, "reward_margin": 0.2868, "lr": 4.949999999999999e-07, "elapsed_s": 81.8}
|
||||
{"step": 110, "loss": 0.6282, "accuracy": 0.6125, "reward_margin": 0.2265, "lr": 4.998898801231603e-07, "elapsed_s": 89.6}
|
||||
{"step": 120, "loss": 0.5832, "accuracy": 0.6375, "reward_margin": 0.3602, "lr": 4.995093745023968e-07, "elapsed_s": 97.5}
|
||||
{"step": 130, "loss": 0.6263, "accuracy": 0.6375, "reward_margin": 0.2447, "lr": 4.988576407978475e-07, "elapsed_s": 105.3}
|
||||
{"step": 140, "loss": 0.5755, "accuracy": 0.7125, "reward_margin": 0.4094, "lr": 4.979355650254416e-07, "elapsed_s": 113.0}
|
||||
{"step": 150, "loss": 0.6593, "accuracy": 0.6125, "reward_margin": 0.271, "lr": 4.967444007244951e-07, "elapsed_s": 120.8}
|
||||
{"step": 160, "loss": 0.6044, "accuracy": 0.6, "reward_margin": 0.3641, "lr": 4.952857672535551e-07, "elapsed_s": 128.7}
|
||||
{"step": 170, "loss": 0.5297, "accuracy": 0.7, "reward_margin": 0.5777, "lr": 4.935616475889216e-07, "elapsed_s": 136.5}
|
||||
{"step": 180, "loss": 0.6645, "accuracy": 0.625, "reward_margin": 0.3387, "lr": 4.9157438562884e-07, "elapsed_s": 144.3}
|
||||
{"step": 190, "loss": 0.5507, "accuracy": 0.625, "reward_margin": 0.509, "lr": 4.893266830070295e-07, "elapsed_s": 152.1}
|
||||
{"step": 200, "loss": 0.5935, "accuracy": 0.6125, "reward_margin": 0.4919, "lr": 4.86821595419878e-07, "elapsed_s": 159.8}
|
||||
{"step": 210, "loss": 0.5545, "accuracy": 0.6125, "reward_margin": 0.486, "lr": 4.840625284722983e-07, "elapsed_s": 179.6}
|
||||
{"step": 220, "loss": 0.601, "accuracy": 0.6, "reward_margin": 0.3861, "lr": 4.810532330478923e-07, "elapsed_s": 187.5}
|
||||
{"step": 230, "loss": 0.5981, "accuracy": 0.5875, "reward_margin": 0.3972, "lr": 4.777978002097169e-07, "elapsed_s": 195.4}
|
||||
{"step": 240, "loss": 0.552, "accuracy": 0.6625, "reward_margin": 0.5146, "lr": 4.743006556385841e-07, "elapsed_s": 203.3}
|
||||
{"step": 250, "loss": 0.5747, "accuracy": 0.6, "reward_margin": 0.5023, "lr": 4.7056655361645756e-07, "elapsed_s": 211.1}
|
||||
{"step": 260, "loss": 0.6258, "accuracy": 0.575, "reward_margin": 0.3457, "lr": 4.666005705631227e-07, "elapsed_s": 219.0}
|
||||
{"step": 270, "loss": 0.5409, "accuracy": 0.6625, "reward_margin": 0.7137, "lr": 4.6240809813491944e-07, "elapsed_s": 226.9}
|
||||
{"step": 280, "loss": 0.6261, "accuracy": 0.55, "reward_margin": 0.3915, "lr": 4.579948358949176e-07, "elapsed_s": 234.7}
|
||||
{"step": 290, "loss": 0.5672, "accuracy": 0.5875, "reward_margin": 0.5594, "lr": 4.5336678356450174e-07, "elapsed_s": 242.6}
|
||||
{"step": 300, "loss": 0.5656, "accuracy": 0.65, "reward_margin": 0.5763, "lr": 4.485302328668972e-07, "elapsed_s": 250.5}
|
||||
{"step": 310, "loss": 0.5736, "accuracy": 0.6625, "reward_margin": 0.5858, "lr": 4.4349175897372746e-07, "elapsed_s": 258.4}
|
||||
{"step": 320, "loss": 0.5536, "accuracy": 0.675, "reward_margin": 0.5922, "lr": 4.3825821156623e-07, "elapsed_s": 266.3}
|
||||
{"step": 330, "loss": 0.5818, "accuracy": 0.65, "reward_margin": 0.4938, "lr": 4.328367055232836e-07, "elapsed_s": 276.1}
|
||||
{"step": 340, "loss": 0.517, "accuracy": 0.7375, "reward_margin": 0.8002, "lr": 4.2723461124890523e-07, "elapsed_s": 284.0}
|
||||
{"step": 350, "loss": 0.54, "accuracy": 0.7375, "reward_margin": 0.6545, "lr": 4.2145954465236736e-07, "elapsed_s": 291.8}
|
||||
{"step": 360, "loss": 0.5317, "accuracy": 0.7125, "reward_margin": 0.6229, "lr": 4.155193567945568e-07, "elapsed_s": 299.6}
|
||||
{"step": 370, "loss": 0.6032, "accuracy": 0.65, "reward_margin": 0.4373, "lr": 4.094221232146508e-07, "elapsed_s": 307.5}
|
||||
{"step": 380, "loss": 0.6854, "accuracy": 0.575, "reward_margin": 0.2799, "lr": 4.0317613295162e-07, "elapsed_s": 315.4}
|
||||
{"step": 390, "loss": 0.5968, "accuracy": 0.6125, "reward_margin": 0.4253, "lr": 3.967898772754842e-07, "elapsed_s": 323.3}
|
||||
{"step": 400, "loss": 0.5062, "accuracy": 0.7625, "reward_margin": 0.7942, "lr": 3.9027203814363984e-07, "elapsed_s": 331.2}
|
||||
{"step": 410, "loss": 0.648, "accuracy": 0.5625, "reward_margin": 0.2849, "lr": 3.8363147639795234e-07, "elapsed_s": 353.3}
|
||||
{"step": 420, "loss": 0.6355, "accuracy": 0.625, "reward_margin": 0.4522, "lr": 3.7687721971866007e-07, "elapsed_s": 361.3}
|
||||
{"step": 430, "loss": 0.5253, "accuracy": 0.7, "reward_margin": 0.5549, "lr": 3.7001845035146485e-07, "elapsed_s": 369.2}
|
||||
{"step": 440, "loss": 0.5343, "accuracy": 0.7125, "reward_margin": 0.6107, "lr": 3.6306449262449543e-07, "elapsed_s": 377.1}
|
||||
{"step": 450, "loss": 0.479, "accuracy": 0.775, "reward_margin": 0.7882, "lr": 3.560248002721124e-07, "elapsed_s": 385.0}
|
||||
{"step": 460, "loss": 0.5344, "accuracy": 0.65, "reward_margin": 0.6866, "lr": 3.4890894358278937e-07, "elapsed_s": 392.9}
|
||||
{"step": 470, "loss": 0.534, "accuracy": 0.675, "reward_margin": 0.6002, "lr": 3.417265963885413e-07, "elapsed_s": 400.8}
|
||||
{"step": 480, "loss": 0.5583, "accuracy": 0.5625, "reward_margin": 0.7706, "lr": 3.3448752291358786e-07, "elapsed_s": 408.7}
|
||||
{"step": 490, "loss": 0.5412, "accuracy": 0.65, "reward_margin": 0.7216, "lr": 3.272015645001312e-07, "elapsed_s": 416.6}
|
||||
{"step": 500, "loss": 0.542, "accuracy": 0.6625, "reward_margin": 0.8168, "lr": 3.1987862622929316e-07, "elapsed_s": 424.5}
|
||||
{"step": 510, "loss": 0.6497, "accuracy": 0.5625, "reward_margin": 0.5525, "lr": 3.125286634554015e-07, "elapsed_s": 432.4}
|
||||
{"step": 520, "loss": 0.5132, "accuracy": 0.7125, "reward_margin": 0.712, "lr": 3.0516166827193075e-07, "elapsed_s": 440.4}
|
||||
{"step": 530, "loss": 0.5197, "accuracy": 0.6375, "reward_margin": 0.7373, "lr": 2.977876559274969e-07, "elapsed_s": 448.3}
|
||||
{"step": 540, "loss": 0.4932, "accuracy": 0.7125, "reward_margin": 0.8599, "lr": 2.9041665121037345e-07, "elapsed_s": 456.2}
|
||||
{"step": 550, "loss": 0.5508, "accuracy": 0.7, "reward_margin": 0.6573, "lr": 2.8305867482003896e-07, "elapsed_s": 464.1}
|
||||
{"step": 560, "loss": 0.5707, "accuracy": 0.6125, "reward_margin": 0.8398, "lr": 2.757237297442821e-07, "elapsed_s": 472.1}
|
||||
{"step": 570, "loss": 0.5083, "accuracy": 0.7375, "reward_margin": 0.8327, "lr": 2.6842178766038637e-07, "elapsed_s": 480.0}
|
||||
{"step": 580, "loss": 0.5289, "accuracy": 0.675, "reward_margin": 0.8025, "lr": 2.611627753788802e-07, "elapsed_s": 487.9}
|
||||
{"step": 590, "loss": 0.5546, "accuracy": 0.6625, "reward_margin": 0.5996, "lr": 2.5395656134828237e-07, "elapsed_s": 495.8}
|
||||
{"step": 600, "loss": 0.5843, "accuracy": 0.65, "reward_margin": 0.5105, "lr": 2.468129422391892e-07, "elapsed_s": 503.8}
|
||||
{"step": 610, "loss": 0.4989, "accuracy": 0.675, "reward_margin": 0.73, "lr": 2.3974162962594177e-07, "elapsed_s": 526.7}
|
||||
{"step": 620, "loss": 0.5662, "accuracy": 0.6875, "reward_margin": 0.5796, "lr": 2.3275223678398024e-07, "elapsed_s": 534.6}
|
||||
{"step": 630, "loss": 0.5067, "accuracy": 0.7, "reward_margin": 0.7825, "lr": 2.2585426562083175e-07, "elapsed_s": 542.5}
|
||||
{"step": 640, "loss": 0.5426, "accuracy": 0.6375, "reward_margin": 0.6038, "lr": 2.1905709375850164e-07, "elapsed_s": 550.4}
|
||||
{"step": 650, "loss": 0.5449, "accuracy": 0.6625, "reward_margin": 0.6423, "lr": 2.1236996178482677e-07, "elapsed_s": 558.3}
|
||||
{"step": 660, "loss": 0.4974, "accuracy": 0.725, "reward_margin": 0.9135, "lr": 2.058019606911242e-07, "elapsed_s": 566.2}
|
||||
{"step": 670, "loss": 0.4447, "accuracy": 0.7125, "reward_margin": 0.955, "lr": 1.9936201951321162e-07, "elapsed_s": 574.1}
|
||||
{"step": 680, "loss": 0.5137, "accuracy": 0.6875, "reward_margin": 0.8216, "lr": 1.9305889319260398e-07, "elapsed_s": 582.0}
|
||||
{"step": 690, "loss": 0.4977, "accuracy": 0.6875, "reward_margin": 0.9253, "lr": 1.869011506743846e-07, "elapsed_s": 589.9}
|
||||
{"step": 700, "loss": 0.5017, "accuracy": 0.675, "reward_margin": 0.8083, "lr": 1.8089716325793666e-07, "elapsed_s": 597.8}
|
||||
{"step": 710, "loss": 0.5133, "accuracy": 0.7125, "reward_margin": 0.872, "lr": 1.7505509321636675e-07, "elapsed_s": 605.7}
|
||||
{"step": 720, "loss": 0.5824, "accuracy": 0.7, "reward_margin": 0.7471, "lr": 1.6938288270009618e-07, "elapsed_s": 613.5}
|
||||
{"step": 730, "loss": 0.5589, "accuracy": 0.6, "reward_margin": 0.6109, "lr": 1.638882429397021e-07, "elapsed_s": 621.4}
|
||||
{"step": 740, "loss": 0.5899, "accuracy": 0.625, "reward_margin": 0.5541, "lr": 1.585786437626905e-07, "elapsed_s": 629.4}
|
||||
{"step": 750, "loss": 0.506, "accuracy": 0.7125, "reward_margin": 0.8685, "lr": 1.5346130343844857e-07, "elapsed_s": 637.3}
|
||||
{"step": 760, "loss": 0.571, "accuracy": 0.6875, "reward_margin": 0.6311, "lr": 1.485431788651856e-07, "elapsed_s": 645.3}
|
||||
{"step": 770, "loss": 0.5197, "accuracy": 0.6875, "reward_margin": 0.7544, "lr": 1.438309561122013e-07, "elapsed_s": 653.2}
|
||||
{"step": 780, "loss": 0.5176, "accuracy": 0.675, "reward_margin": 0.8367, "lr": 1.3933104133033846e-07, "elapsed_s": 661.2}
|
||||
{"step": 790, "loss": 0.6066, "accuracy": 0.6, "reward_margin": 0.5395, "lr": 1.3504955204297946e-07, "elapsed_s": 669.1}
|
||||
{"step": 800, "loss": 0.5212, "accuracy": 0.6875, "reward_margin": 0.7758, "lr": 1.3099230882942304e-07, "elapsed_s": 677.0}
|
||||
{"step": 810, "loss": 0.4471, "accuracy": 0.7625, "reward_margin": 1.008, "lr": 1.2716482741195066e-07, "elapsed_s": 698.5}
|
||||
{"step": 820, "loss": 0.5211, "accuracy": 0.675, "reward_margin": 0.7755, "lr": 1.235723111573371e-07, "elapsed_s": 708.0}
|
||||
{"step": 830, "loss": 0.5618, "accuracy": 0.7, "reward_margin": 0.6765, "lr": 1.2021964400300216e-07, "elapsed_s": 715.8}
|
||||
{"step": 840, "loss": 0.4606, "accuracy": 0.7, "reward_margin": 1.1556, "lr": 1.171113838174174e-07, "elapsed_s": 723.8}
|
||||
{"step": 850, "loss": 0.5364, "accuracy": 0.6625, "reward_margin": 0.7852, "lr": 1.1425175620379659e-07, "elapsed_s": 731.7}
|
||||
{"step": 860, "loss": 0.4544, "accuracy": 0.75, "reward_margin": 1.0015, "lr": 1.1164464875549158e-07, "elapsed_s": 739.6}
|
||||
{"step": 870, "loss": 0.4745, "accuracy": 0.7375, "reward_margin": 0.9336, "lr": 1.0929360577090547e-07, "elapsed_s": 747.4}
|
||||
{"step": 880, "loss": 0.5352, "accuracy": 0.7375, "reward_margin": 0.6781, "lr": 1.0720182343510565e-07, "elapsed_s": 755.4}
|
||||
{"step": 890, "loss": 0.5363, "accuracy": 0.75, "reward_margin": 0.8074, "lr": 1.0537214547468929e-07, "elapsed_s": 763.3}
|
||||
{"step": 900, "loss": 0.5094, "accuracy": 0.625, "reward_margin": 0.8066, "lr": 1.0380705929180662e-07, "elapsed_s": 771.2}
|
||||
{"step": 910, "loss": 0.4909, "accuracy": 0.7125, "reward_margin": 1.0196, "lr": 1.0250869258259928e-07, "elapsed_s": 779.0}
|
||||
{"step": 920, "loss": 0.4845, "accuracy": 0.725, "reward_margin": 0.8083, "lr": 1.0147881044464963e-07, "elapsed_s": 786.9}
|
||||
{"step": 930, "loss": 0.4922, "accuracy": 0.7375, "reward_margin": 0.8589, "lr": 1.0071881297737406e-07, "elapsed_s": 794.8}
|
||||
{"step": 940, "loss": 0.4391, "accuracy": 0.725, "reward_margin": 1.084, "lr": 1.0022973337862222e-07, "elapsed_s": 802.7}
|
||||
{"step": 950, "loss": 0.5421, "accuracy": 0.625, "reward_margin": 0.8424, "lr": 1.0001223654007014e-07, "elapsed_s": 810.6}
|
||||
324
training_logs/sft_log.jsonl
Normal file
324
training_logs/sft_log.jsonl
Normal file
@@ -0,0 +1,324 @@
|
||||
{"step": 10, "epoch": 1, "loss": 2.1095, "lr": 9.000000000000001e-07, "elapsed_s": 11.3}
|
||||
{"step": 20, "epoch": 1, "loss": 1.8581, "lr": 1.9e-06, "elapsed_s": 20.6}
|
||||
{"step": 30, "epoch": 1, "loss": 1.6729, "lr": 2.9e-06, "elapsed_s": 30.0}
|
||||
{"step": 40, "epoch": 1, "loss": 1.6325, "lr": 3.900000000000001e-06, "elapsed_s": 39.3}
|
||||
{"step": 50, "epoch": 1, "loss": 1.5802, "lr": 4.9000000000000005e-06, "elapsed_s": 48.7}
|
||||
{"step": 60, "epoch": 1, "loss": 1.5845, "lr": 5.9e-06, "elapsed_s": 58.1}
|
||||
{"step": 70, "epoch": 1, "loss": 1.5295, "lr": 6.900000000000001e-06, "elapsed_s": 67.4}
|
||||
{"step": 80, "epoch": 1, "loss": 1.5132, "lr": 7.9e-06, "elapsed_s": 76.8}
|
||||
{"step": 90, "epoch": 1, "loss": 1.5379, "lr": 8.900000000000001e-06, "elapsed_s": 86.2}
|
||||
{"step": 100, "epoch": 1, "loss": 1.4603, "lr": 9.9e-06, "elapsed_s": 95.5}
|
||||
{"step": 110, "epoch": 1, "loss": 1.468, "lr": 1.09e-05, "elapsed_s": 104.9}
|
||||
{"step": 120, "epoch": 1, "loss": 1.4928, "lr": 1.1900000000000001e-05, "elapsed_s": 114.3}
|
||||
{"step": 130, "epoch": 1, "loss": 1.4656, "lr": 1.2900000000000002e-05, "elapsed_s": 123.6}
|
||||
{"step": 140, "epoch": 1, "loss": 1.424, "lr": 1.3900000000000002e-05, "elapsed_s": 133.0}
|
||||
{"step": 150, "epoch": 1, "loss": 1.4656, "lr": 1.4900000000000003e-05, "elapsed_s": 142.4}
|
||||
{"step": 160, "epoch": 1, "loss": 1.3754, "lr": 1.59e-05, "elapsed_s": 151.7}
|
||||
{"step": 170, "epoch": 1, "loss": 1.4413, "lr": 1.69e-05, "elapsed_s": 161.1}
|
||||
{"step": 180, "epoch": 1, "loss": 1.4202, "lr": 1.79e-05, "elapsed_s": 170.4}
|
||||
{"step": 190, "epoch": 1, "loss": 1.4522, "lr": 1.8900000000000002e-05, "elapsed_s": 179.8}
|
||||
{"step": 200, "epoch": 1, "loss": 1.3809, "lr": 1.99e-05, "elapsed_s": 189.2}
|
||||
{"step": 210, "epoch": 1, "loss": 1.4069, "lr": 1.9999612774242138e-05, "elapsed_s": 198.5}
|
||||
{"step": 220, "epoch": 1, "loss": 1.4143, "lr": 1.9998274258845686e-05, "elapsed_s": 207.9}
|
||||
{"step": 230, "epoch": 1, "loss": 1.4054, "lr": 1.9995979815408517e-05, "elapsed_s": 217.3}
|
||||
{"step": 240, "epoch": 1, "loss": 1.4373, "lr": 1.9992729687679906e-05, "elapsed_s": 226.6}
|
||||
{"step": 250, "epoch": 1, "loss": 1.3924, "lr": 1.9988524220935858e-05, "elapsed_s": 236.0}
|
||||
{"step": 260, "epoch": 1, "loss": 1.4041, "lr": 1.9983363861942443e-05, "elapsed_s": 245.4}
|
||||
{"step": 270, "epoch": 1, "loss": 1.3964, "lr": 1.997724915890832e-05, "elapsed_s": 254.7}
|
||||
{"step": 280, "epoch": 1, "loss": 1.3783, "lr": 1.9970180761426505e-05, "elapsed_s": 264.1}
|
||||
{"step": 290, "epoch": 1, "loss": 1.3313, "lr": 1.996215942040535e-05, "elapsed_s": 273.4}
|
||||
{"step": 300, "epoch": 1, "loss": 1.3576, "lr": 1.995318598798879e-05, "elapsed_s": 282.8}
|
||||
{"step": 310, "epoch": 1, "loss": 1.3684, "lr": 1.9943261417465805e-05, "elapsed_s": 292.1}
|
||||
{"step": 320, "epoch": 1, "loss": 1.3309, "lr": 1.9932386763169144e-05, "elapsed_s": 301.5}
|
||||
{"step": 330, "epoch": 1, "loss": 1.3728, "lr": 1.9920563180363322e-05, "elapsed_s": 310.9}
|
||||
{"step": 340, "epoch": 1, "loss": 1.4069, "lr": 1.9907791925121902e-05, "elapsed_s": 320.2}
|
||||
{"step": 350, "epoch": 1, "loss": 1.3658, "lr": 1.9894074354194032e-05, "elapsed_s": 329.6}
|
||||
{"step": 360, "epoch": 1, "loss": 1.3596, "lr": 1.987941192486034e-05, "elapsed_s": 339.0}
|
||||
{"step": 370, "epoch": 1, "loss": 1.3535, "lr": 1.986380619477809e-05, "elapsed_s": 348.3}
|
||||
{"step": 380, "epoch": 1, "loss": 1.3938, "lr": 1.984725882181574e-05, "elapsed_s": 357.7}
|
||||
{"step": 390, "epoch": 1, "loss": 1.3241, "lr": 1.9829771563876787e-05, "elapsed_s": 367.0}
|
||||
{"step": 400, "epoch": 1, "loss": 1.3316, "lr": 1.9811346278713027e-05, "elapsed_s": 376.4}
|
||||
{"step": 410, "epoch": 1, "loss": 1.3579, "lr": 1.9791984923727213e-05, "elapsed_s": 385.8}
|
||||
{"step": 420, "epoch": 1, "loss": 1.3888, "lr": 1.9771689555765092e-05, "elapsed_s": 395.1}
|
||||
{"step": 430, "epoch": 1, "loss": 1.3647, "lr": 1.97504623308969e-05, "elapsed_s": 404.5}
|
||||
{"step": 440, "epoch": 1, "loss": 1.395, "lr": 1.9728305504188318e-05, "elapsed_s": 413.8}
|
||||
{"step": 450, "epoch": 1, "loss": 1.3856, "lr": 1.9705221429460907e-05, "elapsed_s": 423.2}
|
||||
{"step": 460, "epoch": 1, "loss": 1.2899, "lr": 1.9681212559042047e-05, "elapsed_s": 432.6}
|
||||
{"step": 470, "epoch": 1, "loss": 1.2937, "lr": 1.9656281443504413e-05, "elapsed_s": 441.9}
|
||||
{"step": 480, "epoch": 1, "loss": 1.3547, "lr": 1.963043073139502e-05, "elapsed_s": 451.3}
|
||||
{"step": 490, "epoch": 1, "loss": 1.341, "lr": 1.9603663168953853e-05, "elapsed_s": 460.7}
|
||||
{"step": 500, "epoch": 1, "loss": 1.3235, "lr": 1.9575981599822124e-05, "elapsed_s": 470.0}
|
||||
{"step": 510, "epoch": 1, "loss": 1.3391, "lr": 1.9547388964740182e-05, "elapsed_s": 490.7}
|
||||
{"step": 520, "epoch": 1, "loss": 1.3466, "lr": 1.951788830123509e-05, "elapsed_s": 500.1}
|
||||
{"step": 530, "epoch": 1, "loss": 1.3823, "lr": 1.9487482743297954e-05, "elapsed_s": 509.5}
|
||||
{"step": 540, "epoch": 1, "loss": 1.3413, "lr": 1.945617552105097e-05, "elapsed_s": 518.8}
|
||||
{"step": 550, "epoch": 1, "loss": 1.3231, "lr": 1.9423969960404283e-05, "elapsed_s": 528.2}
|
||||
{"step": 560, "epoch": 1, "loss": 1.3231, "lr": 1.939086948270265e-05, "elapsed_s": 537.5}
|
||||
{"step": 570, "epoch": 1, "loss": 1.3111, "lr": 1.9356877604361987e-05, "elapsed_s": 546.9}
|
||||
{"step": 580, "epoch": 1, "loss": 1.3378, "lr": 1.9321997936495792e-05, "elapsed_s": 556.2}
|
||||
{"step": 590, "epoch": 1, "loss": 1.3231, "lr": 1.9286234184531536e-05, "elapsed_s": 565.6}
|
||||
{"step": 600, "epoch": 1, "loss": 1.363, "lr": 1.924959014781699e-05, "elapsed_s": 575.0}
|
||||
{"step": 610, "epoch": 1, "loss": 1.2943, "lr": 1.9212069719216638e-05, "elapsed_s": 584.3}
|
||||
{"step": 620, "epoch": 1, "loss": 1.3628, "lr": 1.9173676884698097e-05, "elapsed_s": 593.7}
|
||||
{"step": 630, "epoch": 1, "loss": 1.3413, "lr": 1.9134415722908673e-05, "elapsed_s": 603.1}
|
||||
{"step": 640, "epoch": 1, "loss": 1.3496, "lr": 1.909429040474207e-05, "elapsed_s": 612.4}
|
||||
{"step": 650, "epoch": 1, "loss": 1.3215, "lr": 1.9053305192895297e-05, "elapsed_s": 621.8}
|
||||
{"step": 660, "epoch": 1, "loss": 1.3538, "lr": 1.901146444141583e-05, "elapsed_s": 631.2}
|
||||
{"step": 670, "epoch": 1, "loss": 1.3237, "lr": 1.8968772595239035e-05, "elapsed_s": 640.5}
|
||||
{"step": 680, "epoch": 1, "loss": 1.2838, "lr": 1.8925234189716e-05, "elapsed_s": 649.9}
|
||||
{"step": 690, "epoch": 1, "loss": 1.3398, "lr": 1.8880853850131694e-05, "elapsed_s": 659.3}
|
||||
{"step": 700, "epoch": 1, "loss": 1.3126, "lr": 1.883563629121361e-05, "elapsed_s": 668.6}
|
||||
{"step": 710, "epoch": 1, "loss": 1.368, "lr": 1.8789586316630903e-05, "elapsed_s": 678.0}
|
||||
{"step": 720, "epoch": 1, "loss": 1.2845, "lr": 1.874270881848407e-05, "elapsed_s": 687.4}
|
||||
{"step": 730, "epoch": 1, "loss": 1.3338, "lr": 1.8695008776785244e-05, "elapsed_s": 696.7}
|
||||
{"step": 740, "epoch": 1, "loss": 1.3256, "lr": 1.8646491258929136e-05, "elapsed_s": 706.1}
|
||||
{"step": 750, "epoch": 1, "loss": 1.379, "lr": 1.8597161419154707e-05, "elapsed_s": 715.4}
|
||||
{"step": 760, "epoch": 1, "loss": 1.3103, "lr": 1.8547024497997615e-05, "elapsed_s": 724.8}
|
||||
{"step": 770, "epoch": 1, "loss": 1.3564, "lr": 1.8496085821733482e-05, "elapsed_s": 734.2}
|
||||
{"step": 780, "epoch": 1, "loss": 1.3612, "lr": 1.844435080181205e-05, "elapsed_s": 743.5}
|
||||
{"step": 790, "epoch": 1, "loss": 1.3408, "lr": 1.839182493428233e-05, "elapsed_s": 752.9}
|
||||
{"step": 800, "epoch": 1, "loss": 1.3056, "lr": 1.8338513799208684e-05, "elapsed_s": 762.2}
|
||||
{"step": 810, "epoch": 1, "loss": 1.3447, "lr": 1.8284423060078082e-05, "elapsed_s": 771.6}
|
||||
{"step": 820, "epoch": 1, "loss": 1.3323, "lr": 1.8229558463198396e-05, "elapsed_s": 781.0}
|
||||
{"step": 830, "epoch": 1, "loss": 1.3024, "lr": 1.8173925837087975e-05, "elapsed_s": 790.3}
|
||||
{"step": 840, "epoch": 1, "loss": 1.3355, "lr": 1.8117531091856436e-05, "elapsed_s": 799.7}
|
||||
{"step": 850, "epoch": 1, "loss": 1.3239, "lr": 1.8060380218576828e-05, "elapsed_s": 809.0}
|
||||
{"step": 860, "epoch": 1, "loss": 1.3156, "lr": 1.8002479288649142e-05, "elapsed_s": 818.4}
|
||||
{"step": 870, "epoch": 1, "loss": 1.3339, "lr": 1.794383445315534e-05, "elapsed_s": 827.8}
|
||||
{"step": 880, "epoch": 1, "loss": 1.3303, "lr": 1.7884451942205902e-05, "elapsed_s": 837.1}
|
||||
{"step": 890, "epoch": 1, "loss": 1.2924, "lr": 1.782433806427795e-05, "elapsed_s": 846.5}
|
||||
{"step": 900, "epoch": 1, "loss": 1.3421, "lr": 1.7763499205545092e-05, "elapsed_s": 855.9}
|
||||
{"step": 910, "epoch": 1, "loss": 1.2697, "lr": 1.7701941829198966e-05, "elapsed_s": 865.2}
|
||||
{"step": 920, "epoch": 1, "loss": 1.3115, "lr": 1.7639672474762658e-05, "elapsed_s": 874.6}
|
||||
{"step": 930, "epoch": 1, "loss": 1.2875, "lr": 1.7576697757395946e-05, "elapsed_s": 884.0}
|
||||
{"step": 940, "epoch": 1, "loss": 1.2877, "lr": 1.7513024367192556e-05, "elapsed_s": 893.3}
|
||||
{"step": 950, "epoch": 1, "loss": 1.3348, "lr": 1.7448659068469446e-05, "elapsed_s": 902.7}
|
||||
{"step": 960, "epoch": 1, "loss": 1.3133, "lr": 1.7383608699048193e-05, "elapsed_s": 912.1}
|
||||
{"step": 970, "epoch": 1, "loss": 1.3313, "lr": 1.731788016952859e-05, "elapsed_s": 921.4}
|
||||
{"step": 980, "epoch": 1, "loss": 1.3189, "lr": 1.725148046255449e-05, "elapsed_s": 930.8}
|
||||
{"step": 990, "epoch": 1, "loss": 1.3097, "lr": 1.7184416632072002e-05, "elapsed_s": 940.2}
|
||||
{"step": 1000, "epoch": 1, "loss": 1.2952, "lr": 1.7116695802580155e-05, "elapsed_s": 949.5}
|
||||
{"step": 1010, "epoch": 1, "loss": 1.3551, "lr": 1.7048325168373977e-05, "elapsed_s": 974.2}
|
||||
{"step": 1020, "epoch": 1, "loss": 1.3245, "lr": 1.697931199278025e-05, "elapsed_s": 983.6}
|
||||
{"step": 1030, "epoch": 1, "loss": 1.302, "lr": 1.690966360738588e-05, "elapsed_s": 993.0}
|
||||
{"step": 1040, "epoch": 1, "loss": 1.2576, "lr": 1.6839387411259027e-05, "elapsed_s": 1002.3}
|
||||
{"step": 1050, "epoch": 1, "loss": 1.3409, "lr": 1.676849087016308e-05, "elapsed_s": 1011.7}
|
||||
{"step": 1060, "epoch": 1, "loss": 1.3087, "lr": 1.669698151576352e-05, "elapsed_s": 1021.0}
|
||||
{"step": 1070, "epoch": 1, "loss": 1.3103, "lr": 1.662486694482779e-05, "elapsed_s": 1030.4}
|
||||
{"step": 1080, "epoch": 1, "loss": 1.3006, "lr": 1.65521548184183e-05, "elapsed_s": 1039.7}
|
||||
{"step": 1090, "epoch": 1, "loss": 1.3006, "lr": 1.6478852861078486e-05, "elapsed_s": 1049.1}
|
||||
{"step": 1100, "epoch": 1, "loss": 1.3208, "lr": 1.6404968860012266e-05, "elapsed_s": 1058.5}
|
||||
{"step": 1110, "epoch": 1, "loss": 1.3296, "lr": 1.633051066425673e-05, "elapsed_s": 1067.9}
|
||||
{"step": 1120, "epoch": 1, "loss": 1.325, "lr": 1.6255486183848293e-05, "elapsed_s": 1077.2}
|
||||
{"step": 1130, "epoch": 1, "loss": 1.2892, "lr": 1.6179903388982417e-05, "elapsed_s": 1086.6}
|
||||
{"step": 1140, "epoch": 1, "loss": 1.3014, "lr": 1.6103770309166864e-05, "elapsed_s": 1095.9}
|
||||
{"step": 1150, "epoch": 1, "loss": 1.2787, "lr": 1.602709503236869e-05, "elapsed_s": 1105.3}
|
||||
{"step": 1160, "epoch": 1, "loss": 1.3036, "lr": 1.5949885704155044e-05, "elapsed_s": 1114.7}
|
||||
{"step": 1170, "epoch": 1, "loss": 1.2744, "lr": 1.587215052682779e-05, "elapsed_s": 1124.0}
|
||||
{"step": 1180, "epoch": 1, "loss": 1.2592, "lr": 1.5793897758552187e-05, "elapsed_s": 1133.3}
|
||||
{"step": 1190, "epoch": 1, "loss": 1.2926, "lr": 1.571513571247954e-05, "elapsed_s": 1142.7}
|
||||
{"step": 1200, "epoch": 1, "loss": 1.2902, "lr": 1.5635872755864088e-05, "elapsed_s": 1152.1}
|
||||
{"step": 1210, "epoch": 1, "loss": 1.3357, "lr": 1.5556117309174085e-05, "elapsed_s": 1161.4}
|
||||
{"step": 1220, "epoch": 1, "loss": 1.3249, "lr": 1.5475877845197284e-05, "elapsed_s": 1170.8}
|
||||
{"step": 1230, "epoch": 1, "loss": 1.3064, "lr": 1.5395162888140815e-05, "elapsed_s": 1180.2}
|
||||
{"step": 1240, "epoch": 1, "loss": 1.2959, "lr": 1.531398101272562e-05, "elapsed_s": 1189.5}
|
||||
{"step": 1250, "epoch": 1, "loss": 1.2756, "lr": 1.523234084327553e-05, "elapsed_s": 1198.9}
|
||||
{"step": 1260, "epoch": 1, "loss": 1.3085, "lr": 1.5150251052801055e-05, "elapsed_s": 1208.3}
|
||||
{"step": 1270, "epoch": 1, "loss": 1.2786, "lr": 1.5067720362078014e-05, "elapsed_s": 1217.6}
|
||||
{"step": 1280, "epoch": 1, "loss": 1.2962, "lr": 1.498475753872109e-05, "elapsed_s": 1227.0}
|
||||
{"step": 1290, "epoch": 1, "loss": 1.3318, "lr": 1.4901371396252392e-05, "elapsed_s": 1236.4}
|
||||
{"step": 1300, "epoch": 1, "loss": 1.3128, "lr": 1.4817570793165175e-05, "elapsed_s": 1245.7}
|
||||
{"step": 1310, "epoch": 1, "loss": 1.2791, "lr": 1.473336463198275e-05, "elapsed_s": 1255.1}
|
||||
{"step": 1320, "epoch": 1, "loss": 1.2507, "lr": 1.4648761858312718e-05, "elapsed_s": 1264.5}
|
||||
{"step": 1330, "epoch": 1, "loss": 1.3079, "lr": 1.456377145989666e-05, "elapsed_s": 1273.8}
|
||||
{"step": 1340, "epoch": 1, "loss": 1.2738, "lr": 1.4478402465655313e-05, "elapsed_s": 1283.2}
|
||||
{"step": 1350, "epoch": 1, "loss": 1.2544, "lr": 1.4392663944729386e-05, "elapsed_s": 1292.5}
|
||||
{"step": 1360, "epoch": 1, "loss": 1.298, "lr": 1.4306565005516104e-05, "elapsed_s": 1301.9}
|
||||
{"step": 1370, "epoch": 1, "loss": 1.272, "lr": 1.4220114794701593e-05, "elapsed_s": 1311.2}
|
||||
{"step": 1380, "epoch": 1, "loss": 1.2523, "lr": 1.4133322496289168e-05, "elapsed_s": 1320.6}
|
||||
{"step": 1390, "epoch": 1, "loss": 1.258, "lr": 1.4046197330623684e-05, "elapsed_s": 1330.0}
|
||||
{"step": 1400, "epoch": 1, "loss": 1.2995, "lr": 1.3958748553412014e-05, "elapsed_s": 1339.3}
|
||||
{"step": 1410, "epoch": 1, "loss": 1.312, "lr": 1.3870985454739776e-05, "elapsed_s": 1348.7}
|
||||
{"step": 1420, "epoch": 1, "loss": 1.2912, "lr": 1.37829173580844e-05, "elapsed_s": 1358.1}
|
||||
{"step": 1430, "epoch": 1, "loss": 1.2625, "lr": 1.369455361932465e-05, "elapsed_s": 1367.4}
|
||||
{"step": 1440, "epoch": 1, "loss": 1.2606, "lr": 1.3605903625746721e-05, "elapsed_s": 1376.8}
|
||||
{"step": 1450, "epoch": 1, "loss": 1.2693, "lr": 1.3516976795046961e-05, "elapsed_s": 1386.1}
|
||||
{"step": 1460, "epoch": 1, "loss": 1.2725, "lr": 1.3427782574331403e-05, "elapsed_s": 1395.5}
|
||||
{"step": 1470, "epoch": 1, "loss": 1.2454, "lr": 1.3338330439112152e-05, "elapsed_s": 1404.9}
|
||||
{"step": 1480, "epoch": 1, "loss": 1.2741, "lr": 1.3248629892300753e-05, "elapsed_s": 1414.2}
|
||||
{"step": 1490, "epoch": 1, "loss": 1.2666, "lr": 1.3158690463198665e-05, "elapsed_s": 1423.6}
|
||||
{"step": 1500, "epoch": 1, "loss": 1.3066, "lr": 1.3068521706484893e-05, "elapsed_s": 1433.0}
|
||||
{"step": 1510, "epoch": 1, "loss": 1.2859, "lr": 1.2978133201200992e-05, "elapsed_s": 1454.2}
|
||||
{"step": 1520, "epoch": 1, "loss": 1.3685, "lr": 1.2887534549733395e-05, "elapsed_s": 1463.5}
|
||||
{"step": 1530, "epoch": 1, "loss": 1.2583, "lr": 1.279673537679335e-05, "elapsed_s": 1472.9}
|
||||
{"step": 1540, "epoch": 1, "loss": 1.3214, "lr": 1.2705745328394408e-05, "elapsed_s": 1482.2}
|
||||
{"step": 1550, "epoch": 1, "loss": 1.2376, "lr": 1.2614574070827704e-05, "elapsed_s": 1491.6}
|
||||
{"step": 1560, "epoch": 1, "loss": 1.2253, "lr": 1.252323128963506e-05, "elapsed_s": 1501.0}
|
||||
{"step": 1570, "epoch": 1, "loss": 1.2817, "lr": 1.2431726688580025e-05, "elapsed_s": 1510.3}
|
||||
{"step": 1580, "epoch": 1, "loss": 1.2539, "lr": 1.234006998861704e-05, "elapsed_s": 1519.7}
|
||||
{"step": 1590, "epoch": 1, "loss": 1.2711, "lr": 1.224827092685869e-05, "elapsed_s": 1529.1}
|
||||
{"step": 1600, "epoch": 1, "loss": 1.3262, "lr": 1.2156339255541325e-05, "elapsed_s": 1538.4}
|
||||
{"step": 1610, "epoch": 1, "loss": 1.2492, "lr": 1.2064284740989003e-05, "elapsed_s": 1547.8}
|
||||
{"step": 1620, "epoch": 1, "loss": 1.2889, "lr": 1.1972117162575997e-05, "elapsed_s": 1557.2}
|
||||
{"step": 1630, "epoch": 2, "loss": 1.2466, "lr": 1.1879846311687867e-05, "elapsed_s": 1568.2}
|
||||
{"step": 1640, "epoch": 2, "loss": 1.2516, "lr": 1.1787481990681277e-05, "elapsed_s": 1577.5}
|
||||
{"step": 1650, "epoch": 2, "loss": 1.2368, "lr": 1.1695034011842666e-05, "elapsed_s": 1586.9}
|
||||
{"step": 1660, "epoch": 2, "loss": 1.2125, "lr": 1.1602512196345819e-05, "elapsed_s": 1596.2}
|
||||
{"step": 1670, "epoch": 2, "loss": 1.2006, "lr": 1.150992637320853e-05, "elapsed_s": 1605.6}
|
||||
{"step": 1680, "epoch": 2, "loss": 1.2363, "lr": 1.1417286378248416e-05, "elapsed_s": 1615.0}
|
||||
{"step": 1690, "epoch": 2, "loss": 1.2311, "lr": 1.1324602053038026e-05, "elapsed_s": 1624.4}
|
||||
{"step": 1700, "epoch": 2, "loss": 1.2346, "lr": 1.1231883243859305e-05, "elapsed_s": 1633.7}
|
||||
{"step": 1710, "epoch": 2, "loss": 1.2462, "lr": 1.113913980065759e-05, "elapsed_s": 1643.1}
|
||||
{"step": 1720, "epoch": 2, "loss": 1.2654, "lr": 1.104638157599521e-05, "elapsed_s": 1652.4}
|
||||
{"step": 1730, "epoch": 2, "loss": 1.2545, "lr": 1.0953618424004792e-05, "elapsed_s": 1661.8}
|
||||
{"step": 1740, "epoch": 2, "loss": 1.2452, "lr": 1.0860860199342411e-05, "elapsed_s": 1671.2}
|
||||
{"step": 1750, "epoch": 2, "loss": 1.2416, "lr": 1.0768116756140696e-05, "elapsed_s": 1680.5}
|
||||
{"step": 1760, "epoch": 2, "loss": 1.2029, "lr": 1.0675397946961972e-05, "elapsed_s": 1689.9}
|
||||
{"step": 1770, "epoch": 2, "loss": 1.2086, "lr": 1.0582713621751584e-05, "elapsed_s": 1699.2}
|
||||
{"step": 1780, "epoch": 2, "loss": 1.2165, "lr": 1.049007362679147e-05, "elapsed_s": 1708.6}
|
||||
{"step": 1790, "epoch": 2, "loss": 1.1799, "lr": 1.039748780365418e-05, "elapsed_s": 1718.0}
|
||||
{"step": 1800, "epoch": 2, "loss": 1.2598, "lr": 1.0304965988157335e-05, "elapsed_s": 1727.3}
|
||||
{"step": 1810, "epoch": 2, "loss": 1.2294, "lr": 1.0212518009318725e-05, "elapsed_s": 1736.7}
|
||||
{"step": 1820, "epoch": 2, "loss": 1.2215, "lr": 1.0120153688312134e-05, "elapsed_s": 1746.1}
|
||||
{"step": 1830, "epoch": 2, "loss": 1.2428, "lr": 1.0027882837424002e-05, "elapsed_s": 1755.4}
|
||||
{"step": 1840, "epoch": 2, "loss": 1.1979, "lr": 9.935715259010998e-06, "elapsed_s": 1764.8}
|
||||
{"step": 1850, "epoch": 2, "loss": 1.2127, "lr": 9.843660744458676e-06, "elapsed_s": 1774.2}
|
||||
{"step": 1860, "epoch": 2, "loss": 1.2166, "lr": 9.751729073141308e-06, "elapsed_s": 1783.6}
|
||||
{"step": 1870, "epoch": 2, "loss": 1.2319, "lr": 9.659930011382963e-06, "elapsed_s": 1793.0}
|
||||
{"step": 1880, "epoch": 2, "loss": 1.1878, "lr": 9.568273311419975e-06, "elapsed_s": 1802.3}
|
||||
{"step": 1890, "epoch": 2, "loss": 1.2353, "lr": 9.476768710364943e-06, "elapsed_s": 1811.7}
|
||||
{"step": 1900, "epoch": 2, "loss": 1.1796, "lr": 9.385425929172294e-06, "elapsed_s": 1821.1}
|
||||
{"step": 1910, "epoch": 2, "loss": 1.1965, "lr": 9.294254671605594e-06, "elapsed_s": 1830.4}
|
||||
{"step": 1920, "epoch": 2, "loss": 1.2011, "lr": 9.20326462320665e-06, "elapsed_s": 1839.8}
|
||||
{"step": 1930, "epoch": 2, "loss": 1.207, "lr": 9.112465450266603e-06, "elapsed_s": 1849.2}
|
||||
{"step": 1940, "epoch": 2, "loss": 1.2013, "lr": 9.021866798799013e-06, "elapsed_s": 1858.6}
|
||||
{"step": 1950, "epoch": 2, "loss": 1.2454, "lr": 8.931478293515108e-06, "elapsed_s": 1867.9}
|
||||
{"step": 1960, "epoch": 2, "loss": 1.2415, "lr": 8.841309536801337e-06, "elapsed_s": 1877.3}
|
||||
{"step": 1970, "epoch": 2, "loss": 1.2149, "lr": 8.751370107699245e-06, "elapsed_s": 1886.7}
|
||||
{"step": 1980, "epoch": 2, "loss": 1.2339, "lr": 8.66166956088785e-06, "elapsed_s": 1896.1}
|
||||
{"step": 1990, "epoch": 2, "loss": 1.2489, "lr": 8.572217425668599e-06, "elapsed_s": 1905.4}
|
||||
{"step": 2000, "epoch": 2, "loss": 1.2701, "lr": 8.48302320495304e-06, "elapsed_s": 1914.8}
|
||||
{"step": 2010, "epoch": 2, "loss": 1.2353, "lr": 8.394096374253282e-06, "elapsed_s": 1935.5}
|
||||
{"step": 2020, "epoch": 2, "loss": 1.2168, "lr": 8.30544638067535e-06, "elapsed_s": 1944.9}
|
||||
{"step": 2030, "epoch": 2, "loss": 1.2411, "lr": 8.217082641915602e-06, "elapsed_s": 1954.3}
|
||||
{"step": 2040, "epoch": 2, "loss": 1.2011, "lr": 8.129014545260226e-06, "elapsed_s": 1963.6}
|
||||
{"step": 2050, "epoch": 2, "loss": 1.2018, "lr": 8.041251446587989e-06, "elapsed_s": 1973.0}
|
||||
{"step": 2060, "epoch": 2, "loss": 1.2175, "lr": 7.953802669376318e-06, "elapsed_s": 1982.3}
|
||||
{"step": 2070, "epoch": 2, "loss": 1.2739, "lr": 7.866677503710832e-06, "elapsed_s": 1991.7}
|
||||
{"step": 2080, "epoch": 2, "loss": 1.2066, "lr": 7.779885205298407e-06, "elapsed_s": 2001.1}
|
||||
{"step": 2090, "epoch": 2, "loss": 1.2177, "lr": 7.693434994483897e-06, "elapsed_s": 2010.5}
|
||||
{"step": 2100, "epoch": 2, "loss": 1.2246, "lr": 7.607336055270615e-06, "elapsed_s": 2019.8}
|
||||
{"step": 2110, "epoch": 2, "loss": 1.1861, "lr": 7.521597534344686e-06, "elapsed_s": 2029.2}
|
||||
{"step": 2120, "epoch": 2, "loss": 1.2146, "lr": 7.436228540103342e-06, "elapsed_s": 2038.6}
|
||||
{"step": 2130, "epoch": 2, "loss": 1.2034, "lr": 7.351238141687283e-06, "elapsed_s": 2048.0}
|
||||
{"step": 2140, "epoch": 2, "loss": 1.2174, "lr": 7.266635368017252e-06, "elapsed_s": 2057.3}
|
||||
{"step": 2150, "epoch": 2, "loss": 1.1906, "lr": 7.182429206834824e-06, "elapsed_s": 2066.7}
|
||||
{"step": 2160, "epoch": 2, "loss": 1.2253, "lr": 7.0986286037476105e-06, "elapsed_s": 2076.0}
|
||||
{"step": 2170, "epoch": 2, "loss": 1.2498, "lr": 7.0152424612789135e-06, "elapsed_s": 2085.4}
|
||||
{"step": 2180, "epoch": 2, "loss": 1.2238, "lr": 6.932279637921987e-06, "elapsed_s": 2094.8}
|
||||
{"step": 2190, "epoch": 2, "loss": 1.1924, "lr": 6.8497489471989465e-06, "elapsed_s": 2104.1}
|
||||
{"step": 2200, "epoch": 2, "loss": 1.2275, "lr": 6.767659156724471e-06, "elapsed_s": 2113.5}
|
||||
{"step": 2210, "epoch": 2, "loss": 1.2431, "lr": 6.686018987274381e-06, "elapsed_s": 2122.9}
|
||||
{"step": 2220, "epoch": 2, "loss": 1.2693, "lr": 6.604837111859187e-06, "elapsed_s": 2132.3}
|
||||
{"step": 2230, "epoch": 2, "loss": 1.2064, "lr": 6.524122154802721e-06, "elapsed_s": 2141.6}
|
||||
{"step": 2240, "epoch": 2, "loss": 1.2139, "lr": 6.44388269082592e-06, "elapsed_s": 2151.0}
|
||||
{"step": 2250, "epoch": 2, "loss": 1.2408, "lr": 6.3641272441359165e-06, "elapsed_s": 2160.4}
|
||||
{"step": 2260, "epoch": 2, "loss": 1.2177, "lr": 6.28486428752046e-06, "elapsed_s": 2169.7}
|
||||
{"step": 2270, "epoch": 2, "loss": 1.2337, "lr": 6.206102241447814e-06, "elapsed_s": 2179.1}
|
||||
{"step": 2280, "epoch": 2, "loss": 1.2438, "lr": 6.127849473172208e-06, "elapsed_s": 2188.5}
|
||||
{"step": 2290, "epoch": 2, "loss": 1.2196, "lr": 6.050114295844959e-06, "elapsed_s": 2197.8}
|
||||
{"step": 2300, "epoch": 2, "loss": 1.2225, "lr": 5.972904967631312e-06, "elapsed_s": 2207.2}
|
||||
{"step": 2310, "epoch": 2, "loss": 1.1983, "lr": 5.8962296908331385e-06, "elapsed_s": 2216.6}
|
||||
{"step": 2320, "epoch": 2, "loss": 1.19, "lr": 5.820096611017584e-06, "elapsed_s": 2226.0}
|
||||
{"step": 2330, "epoch": 2, "loss": 1.2123, "lr": 5.744513816151708e-06, "elapsed_s": 2235.3}
|
||||
{"step": 2340, "epoch": 2, "loss": 1.2204, "lr": 5.6694893357432744e-06, "elapsed_s": 2244.7}
|
||||
{"step": 2350, "epoch": 2, "loss": 1.2133, "lr": 5.595031139987734e-06, "elapsed_s": 2254.1}
|
||||
{"step": 2360, "epoch": 2, "loss": 1.2689, "lr": 5.5211471389215135e-06, "elapsed_s": 2263.5}
|
||||
{"step": 2370, "epoch": 2, "loss": 1.2408, "lr": 5.447845181581706e-06, "elapsed_s": 2272.9}
|
||||
{"step": 2380, "epoch": 2, "loss": 1.1791, "lr": 5.37513305517221e-06, "elapsed_s": 2282.2}
|
||||
{"step": 2390, "epoch": 2, "loss": 1.2051, "lr": 5.303018484236485e-06, "elapsed_s": 2291.6}
|
||||
{"step": 2400, "epoch": 2, "loss": 1.227, "lr": 5.23150912983692e-06, "elapsed_s": 2301.0}
|
||||
{"step": 2410, "epoch": 2, "loss": 1.2673, "lr": 5.160612588740973e-06, "elapsed_s": 2310.4}
|
||||
{"step": 2420, "epoch": 2, "loss": 1.2131, "lr": 5.090336392614121e-06, "elapsed_s": 2319.7}
|
||||
{"step": 2430, "epoch": 2, "loss": 1.214, "lr": 5.020688007219751e-06, "elapsed_s": 2329.1}
|
||||
{"step": 2440, "epoch": 2, "loss": 1.2311, "lr": 4.951674831626027e-06, "elapsed_s": 2338.5}
|
||||
{"step": 2450, "epoch": 2, "loss": 1.2143, "lr": 4.883304197419848e-06, "elapsed_s": 2347.8}
|
||||
{"step": 2460, "epoch": 2, "loss": 1.2233, "lr": 4.815583367927997e-06, "elapsed_s": 2357.2}
|
||||
{"step": 2470, "epoch": 2, "loss": 1.2608, "lr": 4.748519537445514e-06, "elapsed_s": 2366.6}
|
||||
{"step": 2480, "epoch": 2, "loss": 1.2123, "lr": 4.682119830471411e-06, "elapsed_s": 2375.9}
|
||||
{"step": 2490, "epoch": 2, "loss": 1.2159, "lr": 4.616391300951807e-06, "elapsed_s": 2385.3}
|
||||
{"step": 2500, "epoch": 2, "loss": 1.2447, "lr": 4.551340931530556e-06, "elapsed_s": 2394.7}
|
||||
{"step": 2510, "epoch": 2, "loss": 1.1837, "lr": 4.486975632807449e-06, "elapsed_s": 2415.2}
|
||||
{"step": 2520, "epoch": 2, "loss": 1.2267, "lr": 4.423302242604059e-06, "elapsed_s": 2424.6}
|
||||
{"step": 2530, "epoch": 2, "loss": 1.1857, "lr": 4.360327525237345e-06, "elapsed_s": 2433.9}
|
||||
{"step": 2540, "epoch": 2, "loss": 1.2175, "lr": 4.298058170801035e-06, "elapsed_s": 2443.3}
|
||||
{"step": 2550, "epoch": 2, "loss": 1.215, "lr": 4.236500794454911e-06, "elapsed_s": 2452.7}
|
||||
{"step": 2560, "epoch": 2, "loss": 1.2179, "lr": 4.17566193572205e-06, "elapsed_s": 2462.1}
|
||||
{"step": 2570, "epoch": 2, "loss": 1.2297, "lr": 4.1155480577940984e-06, "elapsed_s": 2471.4}
|
||||
{"step": 2580, "epoch": 2, "loss": 1.245, "lr": 4.056165546844662e-06, "elapsed_s": 2480.8}
|
||||
{"step": 2590, "epoch": 2, "loss": 1.1813, "lr": 3.997520711350863e-06, "elapsed_s": 2490.2}
|
||||
{"step": 2600, "epoch": 2, "loss": 1.1759, "lr": 3.939619781423175e-06, "elapsed_s": 2499.6}
|
||||
{"step": 2610, "epoch": 2, "loss": 1.1998, "lr": 3.882468908143565e-06, "elapsed_s": 2508.9}
|
||||
{"step": 2620, "epoch": 2, "loss": 1.1735, "lr": 3.826074162912028e-06, "elapsed_s": 2518.3}
|
||||
{"step": 2630, "epoch": 2, "loss": 1.2292, "lr": 3.770441536801607e-06, "elapsed_s": 2527.7}
|
||||
{"step": 2640, "epoch": 2, "loss": 1.178, "lr": 3.71557693992192e-06, "elapsed_s": 2537.1}
|
||||
{"step": 2650, "epoch": 2, "loss": 1.2265, "lr": 3.6614862007913155e-06, "elapsed_s": 2546.4}
|
||||
{"step": 2660, "epoch": 2, "loss": 1.2151, "lr": 3.608175065717676e-06, "elapsed_s": 2555.8}
|
||||
{"step": 2670, "epoch": 2, "loss": 1.1929, "lr": 3.5556491981879526e-06, "elapsed_s": 2565.1}
|
||||
{"step": 2680, "epoch": 2, "loss": 1.1812, "lr": 3.503914178266523e-06, "elapsed_s": 2574.5}
|
||||
{"step": 2690, "epoch": 2, "loss": 1.2171, "lr": 3.452975502002387e-06, "elapsed_s": 2583.9}
|
||||
{"step": 2700, "epoch": 2, "loss": 1.2061, "lr": 3.402838580845295e-06, "elapsed_s": 2593.3}
|
||||
{"step": 2710, "epoch": 2, "loss": 1.2286, "lr": 3.353508741070866e-06, "elapsed_s": 2602.6}
|
||||
{"step": 2720, "epoch": 2, "loss": 1.2086, "lr": 3.3049912232147573e-06, "elapsed_s": 2612.0}
|
||||
{"step": 2730, "epoch": 2, "loss": 1.1707, "lr": 3.257291181515933e-06, "elapsed_s": 2621.4}
|
||||
{"step": 2740, "epoch": 2, "loss": 1.1635, "lr": 3.210413683369101e-06, "elapsed_s": 2630.8}
|
||||
{"step": 2750, "epoch": 2, "loss": 1.1937, "lr": 3.164363708786394e-06, "elapsed_s": 2640.1}
|
||||
{"step": 2760, "epoch": 2, "loss": 1.2802, "lr": 3.119146149868308e-06, "elapsed_s": 2649.5}
|
||||
{"step": 2770, "epoch": 2, "loss": 1.2147, "lr": 3.0747658102840005e-06, "elapsed_s": 2658.9}
|
||||
{"step": 2780, "epoch": 2, "loss": 1.1958, "lr": 3.0312274047609644e-06, "elapsed_s": 2668.3}
|
||||
{"step": 2790, "epoch": 2, "loss": 1.1907, "lr": 2.9885355585841722e-06, "elapsed_s": 2677.6}
|
||||
{"step": 2800, "epoch": 2, "loss": 1.197, "lr": 2.9466948071047043e-06, "elapsed_s": 2687.0}
|
||||
{"step": 2810, "epoch": 2, "loss": 1.1978, "lr": 2.9057095952579336e-06, "elapsed_s": 2696.4}
|
||||
{"step": 2820, "epoch": 2, "loss": 1.2169, "lr": 2.8655842770913302e-06, "elapsed_s": 2705.8}
|
||||
{"step": 2830, "epoch": 2, "loss": 1.2068, "lr": 2.826323115301905e-06, "elapsed_s": 2715.1}
|
||||
{"step": 2840, "epoch": 2, "loss": 1.2295, "lr": 2.7879302807833625e-06, "elapsed_s": 2724.5}
|
||||
{"step": 2850, "epoch": 2, "loss": 1.2024, "lr": 2.7504098521830113e-06, "elapsed_s": 2733.9}
|
||||
{"step": 2860, "epoch": 2, "loss": 1.1798, "lr": 2.713765815468467e-06, "elapsed_s": 2743.3}
|
||||
{"step": 2870, "epoch": 2, "loss": 1.2136, "lr": 2.67800206350421e-06, "elapsed_s": 2752.6}
|
||||
{"step": 2880, "epoch": 2, "loss": 1.2771, "lr": 2.6431223956380163e-06, "elapsed_s": 2762.0}
|
||||
{"step": 2890, "epoch": 2, "loss": 1.2585, "lr": 2.6091305172973524e-06, "elapsed_s": 2771.4}
|
||||
{"step": 2900, "epoch": 2, "loss": 1.2018, "lr": 2.5760300395957185e-06, "elapsed_s": 2780.7}
|
||||
{"step": 2910, "epoch": 2, "loss": 1.2763, "lr": 2.543824478949031e-06, "elapsed_s": 2790.1}
|
||||
{"step": 2920, "epoch": 2, "loss": 1.1883, "lr": 2.5125172567020476e-06, "elapsed_s": 2799.5}
|
||||
{"step": 2930, "epoch": 2, "loss": 1.1856, "lr": 2.4821116987649116e-06, "elapsed_s": 2808.9}
|
||||
{"step": 2940, "epoch": 2, "loss": 1.2025, "lr": 2.4526110352598214e-06, "elapsed_s": 2818.2}
|
||||
{"step": 2950, "epoch": 2, "loss": 1.1775, "lr": 2.424018400177877e-06, "elapsed_s": 2827.6}
|
||||
{"step": 2960, "epoch": 2, "loss": 1.2086, "lr": 2.3963368310461503e-06, "elapsed_s": 2837.0}
|
||||
{"step": 2970, "epoch": 2, "loss": 1.2529, "lr": 2.3695692686049823e-06, "elapsed_s": 2846.3}
|
||||
{"step": 2980, "epoch": 2, "loss": 1.2575, "lr": 2.3437185564955893e-06, "elapsed_s": 2855.7}
|
||||
{"step": 2990, "epoch": 2, "loss": 1.26, "lr": 2.3187874409579548e-06, "elapsed_s": 2865.1}
|
||||
{"step": 3000, "epoch": 2, "loss": 1.2286, "lr": 2.294778570539094e-06, "elapsed_s": 2874.4}
|
||||
{"step": 3010, "epoch": 2, "loss": 1.2415, "lr": 2.2716944958116844e-06, "elapsed_s": 2895.6}
|
||||
{"step": 3020, "epoch": 2, "loss": 1.2682, "lr": 2.2495376691031034e-06, "elapsed_s": 2905.0}
|
||||
{"step": 3030, "epoch": 2, "loss": 1.2742, "lr": 2.2283104442349107e-06, "elapsed_s": 2914.4}
|
||||
{"step": 3040, "epoch": 2, "loss": 1.2586, "lr": 2.208015076272787e-06, "elapsed_s": 2923.7}
|
||||
{"step": 3050, "epoch": 2, "loss": 1.2062, "lr": 2.1886537212869744e-06, "elapsed_s": 2933.1}
|
||||
{"step": 3060, "epoch": 2, "loss": 1.258, "lr": 2.170228436123217e-06, "elapsed_s": 2942.5}
|
||||
{"step": 3070, "epoch": 2, "loss": 1.189, "lr": 2.1527411781842617e-06, "elapsed_s": 2951.9}
|
||||
{"step": 3080, "epoch": 2, "loss": 1.2488, "lr": 2.1361938052219115e-06, "elapsed_s": 2961.2}
|
||||
{"step": 3090, "epoch": 2, "loss": 1.2498, "lr": 2.1205880751396636e-06, "elapsed_s": 2970.6}
|
||||
{"step": 3100, "epoch": 2, "loss": 1.1549, "lr": 2.105925645805969e-06, "elapsed_s": 2979.9}
|
||||
{"step": 3110, "epoch": 2, "loss": 1.2015, "lr": 2.0922080748780995e-06, "elapsed_s": 2989.3}
|
||||
{"step": 3120, "epoch": 2, "loss": 1.2792, "lr": 2.079436819636678e-06, "elapsed_s": 2998.7}
|
||||
{"step": 3130, "epoch": 2, "loss": 1.2746, "lr": 2.0676132368308576e-06, "elapsed_s": 3008.0}
|
||||
{"step": 3140, "epoch": 2, "loss": 1.2482, "lr": 2.056738582534195e-06, "elapsed_s": 3017.4}
|
||||
{"step": 3150, "epoch": 2, "loss": 1.2272, "lr": 2.04681401201121e-06, "elapsed_s": 3026.8}
|
||||
{"step": 3160, "epoch": 2, "loss": 1.2122, "lr": 2.037840579594651e-06, "elapsed_s": 3036.1}
|
||||
{"step": 3170, "epoch": 2, "loss": 1.175, "lr": 2.0298192385734965e-06, "elapsed_s": 3045.5}
|
||||
{"step": 3180, "epoch": 2, "loss": 1.1944, "lr": 2.0227508410916793e-06, "elapsed_s": 3054.8}
|
||||
{"step": 3190, "epoch": 2, "loss": 1.1978, "lr": 2.016636138057557e-06, "elapsed_s": 3064.2}
|
||||
{"step": 3200, "epoch": 2, "loss": 1.2078, "lr": 2.011475779064144e-06, "elapsed_s": 3073.6}
|
||||
{"step": 3210, "epoch": 2, "loss": 1.2308, "lr": 2.0072703123200985e-06, "elapsed_s": 3082.9}
|
||||
{"step": 3220, "epoch": 2, "loss": 1.2213, "lr": 2.0040201845914854e-06, "elapsed_s": 3092.3}
|
||||
{"step": 3230, "epoch": 2, "loss": 1.1818, "lr": 2.001725741154316e-06, "elapsed_s": 3101.7}
|
||||
{"step": 3240, "epoch": 2, "loss": 1.2036, "lr": 2.0003872257578625e-06, "elapsed_s": 3111.1}
|
||||
1966
training_logs/train_log.jsonl
Normal file
1966
training_logs/train_log.jsonl
Normal file
File diff suppressed because it is too large
Load Diff
3
training_report.pdf
Normal file
3
training_report.pdf
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d1383cb9c7d854ee150b21f81c8f86d7b7201b7ec00df7a6cbcf6db896a3dde2
|
||||
size 124746
|
||||
Reference in New Issue
Block a user