初始化项目,由ModelHub XC社区提供模型
Model: openbmb/BitCPM-CANN-8B-unquantized Source: Original Platform
This commit is contained in:
203
example/train.py
Normal file
203
example/train.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
Continual pretraining script for CPM-2B model using DeepSpeed + HuggingFace Trainer.
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
AutoConfig,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
HfArgumentParser,
|
||||
DataCollatorForLanguageModeling,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
import deepspeed
|
||||
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _patched_no_sync(self):
|
||||
try:
|
||||
with _orig_no_sync(self):
|
||||
yield
|
||||
except AssertionError:
|
||||
yield
|
||||
|
||||
deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
model_name_or_path: str = field(
|
||||
metadata={"help": "Path to pretrained model or model identifier"}
|
||||
)
|
||||
torch_dtype: Optional[str] = field(
|
||||
default="bfloat16",
|
||||
metadata={"help": "torch dtype for model weights (float16, bfloat16, float32)"},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
data_path: str = field(
|
||||
metadata={"help": "Path to training data (parquet file or directory)"}
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=4096,
|
||||
metadata={"help": "Maximum sequence length for training"},
|
||||
)
|
||||
text_column: str = field(
|
||||
default="text",
|
||||
metadata={"help": "Name of the text column in the dataset"},
|
||||
)
|
||||
preprocessing_num_workers: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Number of workers for data preprocessing"},
|
||||
)
|
||||
|
||||
|
||||
def tokenize_and_group(dataset, tokenizer, data_args):
|
||||
"""Tokenize texts and group into chunks of max_seq_length."""
|
||||
|
||||
column_names = dataset.column_names
|
||||
text_column = data_args.text_column
|
||||
if text_column not in column_names:
|
||||
candidates = [c for c in column_names if "text" in c.lower()]
|
||||
if candidates:
|
||||
text_column = candidates[0]
|
||||
else:
|
||||
text_column = column_names[0]
|
||||
logger.warning(f"Column '{data_args.text_column}' not found, using '{text_column}'")
|
||||
|
||||
def tokenize_function(examples):
|
||||
return tokenizer(examples[text_column], add_special_tokens=False)
|
||||
|
||||
tokenized_dataset = dataset.map(
|
||||
tokenize_function,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
remove_columns=column_names,
|
||||
desc="Tokenizing",
|
||||
)
|
||||
|
||||
block_size = data_args.max_seq_length
|
||||
|
||||
def group_texts(examples):
|
||||
concatenated = {k: sum(examples[k], []) for k in examples.keys()}
|
||||
total_length = len(concatenated["input_ids"])
|
||||
total_length = (total_length // block_size) * block_size
|
||||
|
||||
result = {
|
||||
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
||||
for k, t in concatenated.items()
|
||||
}
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
grouped_dataset = tokenized_dataset.map(
|
||||
group_texts,
|
||||
batched=True,
|
||||
num_proc=data_args.preprocessing_num_workers,
|
||||
desc="Grouping texts",
|
||||
)
|
||||
|
||||
return grouped_dataset
|
||||
|
||||
|
||||
def main():
|
||||
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
||||
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
||||
)
|
||||
logger.info(f"Training args: {training_args}")
|
||||
|
||||
set_seed(training_args.seed)
|
||||
|
||||
dtype_map = {
|
||||
"float16": torch.float16,
|
||||
"bfloat16": torch.bfloat16,
|
||||
"float32": torch.float32,
|
||||
}
|
||||
torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)
|
||||
|
||||
logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
logger.info(f"Loading model from {model_args.model_name_or_path}")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_args.model_name_or_path,
|
||||
torch_dtype=torch_dtype,
|
||||
trust_remote_code=True,
|
||||
attn_implementation="sdpa",
|
||||
)
|
||||
model.config.use_cache = False
|
||||
|
||||
logger.info(f"Loading dataset from {data_args.data_path}")
|
||||
if os.path.isfile(data_args.data_path):
|
||||
raw_dataset = load_dataset("parquet", data_files=data_args.data_path, split="train")
|
||||
elif os.path.isdir(data_args.data_path):
|
||||
parquet_files = [
|
||||
os.path.join(data_args.data_path, f)
|
||||
for f in os.listdir(data_args.data_path)
|
||||
if f.endswith(".parquet")
|
||||
]
|
||||
raw_dataset = load_dataset("parquet", data_files=parquet_files, split="train")
|
||||
else:
|
||||
raise ValueError(f"Data path not found: {data_args.data_path}")
|
||||
|
||||
logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
|
||||
|
||||
train_dataset = tokenize_and_group(raw_dataset, tokenizer, data_args)
|
||||
logger.info(f"Processed dataset: {len(train_dataset)} samples of length {data_args.max_seq_length}")
|
||||
|
||||
data_collator = DataCollatorForLanguageModeling(
|
||||
tokenizer=tokenizer,
|
||||
mlm=False,
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
data_collator=data_collator,
|
||||
)
|
||||
|
||||
logger.info("Starting training...")
|
||||
train_result = trainer.train(
|
||||
resume_from_checkpoint=training_args.resume_from_checkpoint
|
||||
)
|
||||
|
||||
trainer.save_model()
|
||||
trainer.save_state()
|
||||
|
||||
metrics = train_result.metrics
|
||||
metrics["train_samples"] = len(train_dataset)
|
||||
trainer.log_metrics("train", metrics)
|
||||
trainer.save_metrics("train", metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user