Files
transformer-1b-chat/training_code/train_dpo.py
ModelHub XC 070e055bf5 初始化项目,由ModelHub XC社区提供模型
Model: divakar-yadav/transformer-1b-chat
Source: Original Platform
2026-06-20 17:27:58 +08:00

328 lines
11 KiB
Python

"""
DPO (Direct Preference Optimization) training for the 1B Transformer.
Takes the SFT model and aligns it with human preferences using
UltraFeedback preference pairs.
DPO Loss (yw = chosen response, yl = rejected response, sigma = logistic sigmoid):
L = -log sigma(beta * (log(pi(yw|x) / pi_ref(yw|x)) - log(pi(yl|x) / pi_ref(yl|x))))
Launch: torchrun --nproc_per_node=8 train_dpo.py
"""
import os
import sys
import math
import time
import json
import datetime
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from model.config import ModelConfig
from model.transformer import Transformer
from model.data import get_tokenizer
from model.dpo_data import DPODataset, dpo_collate_fn
# === Config ===
# Filesystem locations.
SFT_CHECKPOINT = "/jfs/deepak-kumar/checkpoints_sft/sft_final.pt"  # starting weights for both policy and reference
DPO_CHECKPOINT_DIR = "/jfs/deepak-kumar/checkpoints_dpo"  # periodic and final DPO checkpoints are written here
LOG_DIR = "/home/jovyan/training/logs"  # rank 0 writes dpo_log.jsonl here
DATA_CACHE = "/jfs/deepak-kumar/data"  # passed to DPODataset as cache_dir
# Schedule and batch geometry.
NUM_EPOCHS = 1
BATCH_SIZE_PER_GPU = 2  # preference pairs per micro-batch on each rank
GRADIENT_ACCUMULATION = 4 # effective batch = 2 * 8 * 4 = 64 (with the 8-GPU launch from the module docstring)
MAX_SEQ_LEN = 1024  # passed to DPODataset as max_seq_len
# Optimizer settings.
LEARNING_RATE = 5e-7 # very low LR for DPO; peak of the warmup + cosine schedule
MIN_LR = 1e-7  # floor of the cosine decay
WARMUP_STEPS = 100  # optimizer steps of linear warmup
WEIGHT_DECAY = 0.01  # applied only to parameters with dim >= 2
GRAD_CLIP = 1.0  # max gradient norm passed to clip_grad_norm_
BETA = 0.1 # DPO temperature; scales the policy/reference log-ratio in dpo_loss
# Reporting cadence, both measured in optimizer steps.
LOG_INTERVAL = 10
SAVE_INTERVAL = 200
def get_cosine_lr(step, warmup_steps, total_steps, max_lr, min_lr):
    """Return the learning rate for ``step`` under linear warmup + cosine decay.

    Args:
        step: Zero-based optimizer step index.
        warmup_steps: Number of steps over which the rate ramps linearly
            from 0 to ``max_lr``.
        total_steps: Step at which the cosine decay reaches ``min_lr``.
        max_lr: Peak learning rate, reached at the end of warmup.
        min_lr: Floor learning rate, reached at ``total_steps`` and held
            for any later step.

    Returns:
        The learning rate as a float.
    """
    if step < warmup_steps:
        return max_lr * step / max(warmup_steps, 1)

    # Guard against a zero/negative decay window (warmup >= total).
    decay_steps = max(total_steps - warmup_steps, 1)
    # Clamp so that steps past ``total_steps`` stay at ``min_lr``; without the
    # clamp the cosine wraps around and the rate climbs back toward ``max_lr``.
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))
def get_per_token_logps(model, input_ids, prompt_lens):
"""
Compute sum of log probabilities for response tokens only.
input_ids: [B, S] full sequence (prompt + response)
prompt_lens: [B] where response starts
Returns: [B] sum of log probs over response tokens
"""
# Clone input to avoid inplace issues with shared RoPE buffers
inp = input_ids[:, :-1].contiguous()
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
logits, _ = model(inp)
labels = input_ids[:, 1:].contiguous()
log_probs = F.log_softmax(logits.float(), dim=-1)
token_logps = log_probs.gather(2, labels.unsqueeze(2)).squeeze(2)
B, S = token_logps.shape
mask = torch.zeros_like(token_logps)
for b in range(B):
pl = prompt_lens[b].item()
response_start = max(0, pl - 1)
seq_len = (labels[b] != 0).sum().item()
mask[b, response_start:seq_len] = 1.0
return (token_logps * mask).sum(dim=1)
def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Return the DPO loss together with two scalar monitoring metrics.

    The implicit reward of a response is ``beta`` times the difference between
    the policy and reference log-probabilities. The loss is the mean negative
    log-sigmoid of the chosen-minus-rejected reward gap.

    Returns:
        A tuple ``(loss, accuracy, margin)`` where ``loss`` is a tensor that
        carries gradients, ``accuracy`` is the fraction of pairs whose chosen
        reward exceeds the rejected reward (Python float), and ``margin`` is
        the mean reward gap (Python float).
    """
    reward_w = beta * (policy_chosen_logps - ref_chosen_logps)
    reward_l = beta * (policy_rejected_logps - ref_rejected_logps)

    preference_logits = reward_w - reward_l
    loss = F.logsigmoid(preference_logits).mean().neg()

    # Metrics are for logging only, so keep them out of the autograd graph.
    with torch.no_grad():
        accuracy = (reward_w > reward_l).to(torch.float32).mean().item()
        margin = (reward_w - reward_l).mean().item()

    return loss, accuracy, margin
def _save_checkpoint(path, step, policy, model_config):
    """Write the unwrapped (non-DDP) policy weights and model config to ``path``."""
    torch.save({
        "step": step,
        "model": policy.module.state_dict(),
        "config": model_config.__dict__,
        "vocab_size": model_config.vocab_size,
    }, path)


def main():
    """Run DPO training on top of the SFT checkpoint under torch.distributed.

    Steps, in order:
      1. Initialise the NCCL process group and bind this rank to its GPU.
      2. Build the tokenizer (adding the chat special tokens if missing) and
         load the SFT weights into a trainable policy (wrapped in DDP) and a
         frozen bfloat16 reference model.
      3. Iterate the preference dataset with gradient accumulation, applying
         warmup + cosine LR, gradient clipping and AdamW.
      4. On rank 0, log metrics to stdout and ``dpo_log.jsonl`` every
         ``LOG_INTERVAL`` steps, checkpoint every ``SAVE_INTERVAL`` steps, and
         write ``dpo_final.pt`` at the end.
    """
    # Function-scope imports: only this block needs them.
    import contextlib
    import itertools

    dist.init_process_group("nccl", timeout=datetime.timedelta(minutes=30))
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")

    if rank == 0:
        os.makedirs(DPO_CHECKPOINT_DIR, exist_ok=True)
        os.makedirs(LOG_DIR, exist_ok=True)
        print("=" * 70)
        print(" DPO: PREFERENCE ALIGNMENT FOR 1B TRANSFORMER")
        print("=" * 70)

    # --- Tokenizer ---------------------------------------------------------
    tokenizer = get_tokenizer()
    special_tokens = ["<|user|>", "<|assistant|>", "<|end|>"]
    vocab = tokenizer.get_vocab()
    new_tokens = [t for t in special_tokens if t not in vocab]
    if new_tokens:
        tokenizer.add_tokens(new_tokens, special_tokens=True)

    model_config = ModelConfig()
    # The embedding size must match the tokenizer after adding special tokens.
    model_config.vocab_size = len(tokenizer)

    # --- Models ------------------------------------------------------------
    if rank == 0:
        print(f"[Init] Loading SFT model from {SFT_CHECKPOINT}")

    # Policy model (trainable)
    policy = Transformer(model_config)
    # NOTE(security): weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    ckpt = torch.load(SFT_CHECKPOINT, map_location="cpu", weights_only=False)
    policy.load_state_dict(ckpt["model"])
    sft_step = ckpt.get("step", 0)
    if rank == 0:
        print(f"[Init] SFT model loaded (step {sft_step})")

    # Reference model (frozen copy of the same SFT weights)
    ref_model = Transformer(model_config)
    ref_model.load_state_dict(ckpt["model"])
    del ckpt  # release the CPU copy of the weights before moving to GPU

    policy = policy.to(device)
    ref_model = ref_model.to(device).bfloat16()
    ref_model.eval()
    for p in ref_model.parameters():
        p.requires_grad = False

    policy = DDP(policy, device_ids=[local_rank])

    if rank == 0:
        n = sum(p.numel() for p in policy.parameters())
        print(f"[Init] Params: {n:,} | GPUs: {world_size}x H100")
        print(f"[Init] Beta: {BETA} | LR: {LEARNING_RATE}")

    # --- Data --------------------------------------------------------------
    dataset = DPODataset(
        tokenizer=tokenizer,
        max_seq_len=MAX_SEQ_LEN,
        split="train",
        cache_dir=DATA_CACHE,
    )
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    # NOTE(review): get_per_token_logps treats token id 0 as padding by
    # default, while batches here are padded with tokenizer.pad_token_id;
    # confirm the two agree.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE_PER_GPU,
        sampler=sampler,
        num_workers=4,
        pin_memory=True,
        collate_fn=lambda b: dpo_collate_fn(b, pad_id=tokenizer.pad_token_id),
    )

    # Round up: a trailing group with fewer than GRADIENT_ACCUMULATION
    # micro-batches still produces an optimizer step below, so it must be
    # counted or global_step would overshoot total_steps.
    steps_per_epoch = math.ceil(len(dataloader) / GRADIENT_ACCUMULATION)
    total_steps = steps_per_epoch * NUM_EPOCHS

    if rank == 0:
        eff_batch = BATCH_SIZE_PER_GPU * world_size * GRADIENT_ACCUMULATION
        print(f"[Init] Dataset: {len(dataset):,} preference pairs")
        print(f"[Init] Effective batch: {eff_batch} | Steps/epoch: {steps_per_epoch}")
        print(f"[Init] Total steps: {total_steps}")
        print("-" * 70)

    # --- Optimizer ---------------------------------------------------------
    # Weight decay only on tensors with >= 2 dims; 1-D parameters are exempt.
    decay_params = [p for n, p in policy.named_parameters() if p.dim() >= 2 and p.requires_grad]
    nodecay_params = [p for n, p in policy.named_parameters() if p.dim() < 2 and p.requires_grad]
    optimizer = torch.optim.AdamW([
        {"params": decay_params, "weight_decay": WEIGHT_DECAY},
        {"params": nodecay_params, "weight_decay": 0.0},
    ], lr=LEARNING_RATE, betas=(0.9, 0.95), fused=True)

    policy.train()
    global_step = 0
    running_loss = 0.0
    running_acc = 0.0
    running_margin = 0.0
    t0 = time.time()

    log_file = (
        open(os.path.join(LOG_DIR, "dpo_log.jsonl"), "w", encoding="utf-8")
        if rank == 0 else None
    )
    try:
        for epoch in range(NUM_EPOCHS):
            sampler.set_epoch(epoch)
            data_iter = iter(dataloader)
            if rank == 0:
                print(f"\n[Epoch {epoch + 1}/{NUM_EPOCHS}]")

            while True:
                # Pull the whole accumulation group up front so we know how
                # many micro-batches it has (the last group may be short).
                micro_batches = list(itertools.islice(data_iter, GRADIENT_ACCUMULATION))
                if not micro_batches:
                    break
                num_micros = len(micro_batches)

                optimizer.zero_grad(set_to_none=True)
                batch_loss = 0.0
                batch_acc = 0.0
                batch_margin = 0.0

                for idx, batch in enumerate(micro_batches):
                    chosen_ids = batch["chosen_ids"].to(device, non_blocking=True)
                    rejected_ids = batch["rejected_ids"].to(device, non_blocking=True)
                    prompt_lens = batch["prompt_lens"].to(device, non_blocking=True)

                    # Only all-reduce gradients on the last micro-batch of the
                    # group; earlier ones just accumulate locally.
                    is_last = idx == num_micros - 1
                    sync_ctx = contextlib.nullcontext() if is_last else policy.no_sync()
                    with sync_ctx:
                        policy_chosen_logps = get_per_token_logps(policy, chosen_ids, prompt_lens)
                        policy_rejected_logps = get_per_token_logps(policy, rejected_ids, prompt_lens)
                        with torch.no_grad():
                            ref_chosen_logps = get_per_token_logps(ref_model, chosen_ids, prompt_lens)
                            ref_rejected_logps = get_per_token_logps(ref_model, rejected_ids, prompt_lens)

                        loss, acc, margin = dpo_loss(
                            policy_chosen_logps, policy_rejected_logps,
                            ref_chosen_logps, ref_rejected_logps,
                            beta=BETA,
                        )
                        # Normalise by the real group size so a short final
                        # group is averaged correctly rather than down-scaled.
                        loss = loss / num_micros
                        loss.backward()

                    batch_loss += loss.item()
                    batch_acc += acc
                    batch_margin += margin

                torch.nn.utils.clip_grad_norm_(policy.parameters(), GRAD_CLIP)
                lr = get_cosine_lr(global_step, WARMUP_STEPS, total_steps, LEARNING_RATE, MIN_LR)
                for pg in optimizer.param_groups:
                    pg["lr"] = lr
                optimizer.step()
                global_step += 1

                running_loss += batch_loss
                running_acc += batch_acc / num_micros
                running_margin += batch_margin / num_micros

                if global_step % LOG_INTERVAL == 0:
                    avg_loss = running_loss / LOG_INTERVAL
                    avg_acc = running_acc / LOG_INTERVAL
                    avg_margin = running_margin / LOG_INTERVAL
                    elapsed = time.time() - t0
                    pct = 100.0 * global_step / max(total_steps, 1)
                    eta = (elapsed / max(global_step, 1)) * (total_steps - global_step)
                    if rank == 0:
                        gpu_mem = torch.cuda.max_memory_allocated(device) / 1e9
                        print(
                            f" [Step {global_step:>5d}/{total_steps}] "
                            f"loss={avg_loss:.4f} | acc={avg_acc:.1%} | "
                            f"margin={avg_margin:.3f} | lr={lr:.2e} | "
                            f"GPU={gpu_mem:.1f}GB | {pct:.1f}% | ETA={eta/60:.0f}m",
                            flush=True,
                        )
                        if log_file is not None:
                            log_file.write(json.dumps({
                                "step": global_step, "loss": round(avg_loss, 4),
                                "accuracy": round(avg_acc, 4),
                                "reward_margin": round(avg_margin, 4),
                                "lr": lr, "elapsed_s": round(elapsed, 1),
                            }) + "\n")
                            log_file.flush()
                    running_loss = 0.0
                    running_acc = 0.0
                    running_margin = 0.0

                if global_step % SAVE_INTERVAL == 0:
                    # Barriers keep other ranks from racing ahead while rank 0
                    # is writing the checkpoint.
                    dist.barrier()
                    if rank == 0:
                        path = os.path.join(DPO_CHECKPOINT_DIR, f"dpo_step_{global_step}.pt")
                        _save_checkpoint(path, global_step, policy, model_config)
                        print(f" >> Checkpoint: {path}", flush=True)
                    dist.barrier()

        # Final save
        dist.barrier()
        if rank == 0:
            final_path = os.path.join(DPO_CHECKPOINT_DIR, "dpo_final.pt")
            _save_checkpoint(final_path, global_step, policy, model_config)
            total_time = time.time() - t0
            print("=" * 70)
            print(" DPO COMPLETE")
            print(f" Steps: {global_step:,} | Epochs: {NUM_EPOCHS}")
            print(f" Time: {total_time/60:.1f} minutes")
            print(f" Final model: {final_path}")
            print("=" * 70)
    finally:
        # Close the metrics log even if training raises.
        if log_file is not None:
            log_file.close()

    dist.destroy_process_group()
# Start training only when this file is executed as a script (e.g. via
# torchrun), not when it is imported as a module.
if __name__ == "__main__":
    main()