204 lines
5.9 KiB
Python
204 lines
5.9 KiB
Python
"""
Continual pretraining script for the CPM-2B model using DeepSpeed + HuggingFace Trainer.
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import json
|
||
|
|
import math
|
||
|
|
import logging
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from typing import Optional
|
||
|
|
|
||
|
|
import contextlib
|
||
|
|
|
||
|
|
import torch
|
||
|
|
from datasets import load_dataset
|
||
|
|
from transformers import (
|
||
|
|
AutoModelForCausalLM,
|
||
|
|
AutoTokenizer,
|
||
|
|
AutoConfig,
|
||
|
|
Trainer,
|
||
|
|
TrainingArguments,
|
||
|
|
HfArgumentParser,
|
||
|
|
DataCollatorForLanguageModeling,
|
||
|
|
set_seed,
|
||
|
|
)
|
||
|
|
|
||
|
|
import deepspeed
|
||
|
|
# Keep a handle on the stock implementation so the wrapper below can delegate to it.
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync


@contextlib.contextmanager
def _patched_no_sync(self):
    """Best-effort replacement for ``DeepSpeedEngine.no_sync``.

    Delegates to the original ``no_sync``. If *entering* the original context
    manager raises ``AssertionError``, the managed block is still executed,
    just without the original context active.
    NOTE(review): presumably the assertion fires when the engine configuration
    does not support ``no_sync`` -- confirm against the installed DeepSpeed.

    Only the entry step is guarded. An ``AssertionError`` raised by the
    caller's ``with`` body propagates unchanged; catching it around the
    ``yield`` would make this generator yield a second time, which
    ``contextlib`` reports as ``RuntimeError("generator didn't stop after
    throw()")`` and hides the real failure.

    Yields:
        None. The context manager exposes no value.
    """
    with contextlib.ExitStack() as stack:
        try:
            # On success the original manager is registered on the stack, so its
            # __exit__ still runs (and sees body exceptions) when we leave.
            stack.enter_context(_orig_no_sync(self))
        except AssertionError:
            # Deliberate best-effort: the original refused to start, so run the
            # body without it instead of aborting training.
            pass
        yield


# Install the tolerant wrapper in place of the stock method.
deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
|
||
|
|
|
||
|
|
# Module-level logger named after this module; handlers and level are set by
# logging.basicConfig() inside main().
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class ModelArguments:
    """Options that select the pretrained model and its weight precision.

    Consumed by ``HfArgumentParser`` in ``main``; each field's ``help``
    metadata is its command-line description.
    """

    # Required (no default): local path or model identifier to load from.
    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier"),
    )

    # Name of the dtype as a string; main() maps it to a torch dtype.
    torch_dtype: Optional[str] = field(
        metadata=dict(help="torch dtype for model weights (float16, bfloat16, float32)"),
        default="bfloat16",
    )
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class DataArguments:
    """Options that describe the training corpus and how it is preprocessed.

    Consumed by ``HfArgumentParser`` in ``main``; each field's ``help``
    metadata is its command-line description.
    """

    # Required (no default): a parquet file or a directory to read from.
    data_path: str = field(
        metadata=dict(help="Path to training data (parquet file or directory)"),
    )

    # Length of every packed training sample, in tokens.
    max_seq_length: int = field(
        metadata=dict(help="Maximum sequence length for training"),
        default=4096,
    )

    # Column that holds the raw text to tokenize.
    text_column: str = field(
        metadata=dict(help="Name of the text column in the dataset"),
        default="text",
    )

    # Passed as num_proc to the dataset map() calls.
    preprocessing_num_workers: int = field(
        metadata=dict(help="Number of workers for data preprocessing"),
        default=8,
    )
|
||
|
|
|
||
|
|
|
||
|
|
def tokenize_and_group(dataset, tokenizer, data_args):
    """Tokenize texts and group into chunks of max_seq_length.

    The text column is tokenized without special tokens, the resulting token
    lists of each map batch are concatenated, and the concatenation is cut
    into consecutive blocks of exactly ``data_args.max_seq_length`` tokens.
    Tokens left over after the last full block of a batch are dropped.
    ``labels`` is added as a copy of ``input_ids``.

    Args:
        dataset: Object exposing ``column_names`` and a ``map`` method
            accepting ``batched``, ``num_proc``, ``remove_columns`` and
            ``desc`` (a HuggingFace ``datasets.Dataset`` in ``main``).
        tokenizer: Callable invoked as ``tokenizer(texts, add_special_tokens=False)``.
        data_args: Provides ``text_column``, ``max_seq_length`` and
            ``preprocessing_num_workers``.

    Returns:
        The mapped dataset whose rows hold the tokenizer's output keys plus
        ``labels``, each of length ``max_seq_length``.

    Raises:
        ValueError: If the dataset exposes no columns at all.
    """
    from itertools import chain

    column_names = dataset.column_names
    if not column_names:
        # Previously this surfaced as a bare IndexError from column_names[0].
        raise ValueError("Dataset has no columns to tokenize")

    text_column = data_args.text_column
    if text_column not in column_names:
        # Fall back to the first column whose name contains "text", otherwise
        # to the first column, and tell the user which one was picked.
        candidates = [c for c in column_names if "text" in c.lower()]
        text_column = candidates[0] if candidates else column_names[0]
        logger.warning(f"Column '{data_args.text_column}' not found, using '{text_column}'")

    def tokenize_function(examples):
        # NOTE(review): add_special_tokens=False and nothing is appended below,
        # so packed documents are not separated by any boundary token --
        # confirm this is intended for this corpus.
        return tokenizer(examples[text_column], add_special_tokens=False)

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        desc="Tokenizing",
    )

    block_size = data_args.max_seq_length

    def group_texts(examples):
        # Flatten each column's per-document lists in linear time;
        # sum(lists, []) re-copies the accumulated list on every addition.
        concatenated = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
        total_length = len(concatenated["input_ids"])
        # Round down to a whole number of blocks; the tail is discarded.
        total_length = (total_length // block_size) * block_size

        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    grouped_dataset = tokenized_dataset.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        desc="Grouping texts",
    )

    return grouped_dataset
|
||
|
|
|
||
|
|
|
||
|
|
def main():
    """Run continual pretraining end to end.

    Steps, in order:
      1. Parse ``ModelArguments``, ``DataArguments`` and ``TrainingArguments``
         from the command line.
      2. Configure logging (INFO when ``local_rank`` is -1 or 0, WARN
         otherwise) and seed the RNGs.
      3. Load the tokenizer and causal-LM weights from ``model_name_or_path``.
      4. Load parquet data from a single file or from every ``.parquet`` file
         in a directory, then tokenize and pack it.
      5. Train with ``Trainer`` and save the model, trainer state and metrics.

    Raises:
        ValueError: If ``data_path`` is neither a file nor a directory, or is
            a directory that contains no ``.parquet`` files.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.info(f"Training args: {training_args}")

    set_seed(training_args.seed)

    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    requested_dtype = model_args.torch_dtype
    if requested_dtype is not None and requested_dtype not in dtype_map:
        # Keep the historical bfloat16 fallback, but make a typo visible.
        logger.warning(f"Unknown torch_dtype '{requested_dtype}', falling back to bfloat16")
    torch_dtype = dtype_map.get(requested_dtype, torch.bfloat16)

    logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    logger.info(f"Loading model from {model_args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    # Turn off the config's cache flag for training
    # (presumably because the cache only helps generation).
    model.config.use_cache = False

    data_path = data_args.data_path
    logger.info(f"Loading dataset from {data_path}")
    if os.path.isfile(data_path):
        raw_dataset = load_dataset("parquet", data_files=data_path, split="train")
    elif os.path.isdir(data_path):
        # os.listdir() returns entries in arbitrary order; sort so the shard
        # order (and so the packed samples) is the same on every run and rank.
        parquet_files = sorted(
            os.path.join(data_path, name)
            for name in os.listdir(data_path)
            if name.endswith(".parquet")
        )
        if not parquet_files:
            raise ValueError(f"No .parquet files found in directory: {data_path}")
        raw_dataset = load_dataset("parquet", data_files=parquet_files, split="train")
    else:
        raise ValueError(f"Data path not found: {data_path}")

    logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")

    train_dataset = tokenize_and_group(raw_dataset, tokenizer, data_args)
    logger.info(f"Processed dataset: {len(train_dataset)} samples of length {data_args.max_seq_length}")

    # NOTE(review): with mlm=False this collator is understood to rebuild
    # labels from input_ids and exclude pad-token positions from the loss.
    # Because pad_token may alias eos_token (set above), EOS positions could be
    # excluded too -- confirm against the installed transformers version.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )

    logger.info("Starting training...")
    train_result = trainer.train(
        resume_from_checkpoint=training_args.resume_from_checkpoint
    )

    trainer.save_model()
    trainer.save_state()

    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
|
||
|
|
|
||
|
|
|
||
|
|
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()
|