初始化项目,由ModelHub XC社区提供模型
Model: openbmb/BitCPM-CANN-3B-unquantized Source: Original Platform
This commit is contained in:
424
example/train_sft.py
Normal file
424
example/train_sft.py
Normal file
@@ -0,0 +1,424 @@
|
||||
"""
|
||||
Supervised fine-tuning script using DeepSpeed + HuggingFace Trainer.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import contextlib
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
HfArgumentParser,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
import deepspeed
|
||||
# Keep a handle on DeepSpeed's original no_sync so the wrapper defined below can delegate to it.
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync
|
||||
|
||||
@contextlib.contextmanager
def _patched_no_sync(self):
    """Enter DeepSpeed's ``no_sync`` when possible, otherwise run the body as-is.

    The original ``DeepSpeedEngine.no_sync`` can refuse to start by raising
    ``AssertionError``.
    NOTE(review): presumably this happens when the active DeepSpeed
    configuration does not support skipping gradient synchronisation --
    confirm against the installed DeepSpeed version.

    Only an ``AssertionError`` raised while *entering* the original context is
    suppressed; the ``with`` body then simply runs without the no-sync
    behaviour. Exceptions raised by the body itself (including
    ``AssertionError``) or by the original context's exit propagate unchanged.
    A ``contextmanager`` generator must yield exactly once, so the yield is
    kept outside the ``try``/``except``; yielding again from the handler would
    make contextlib raise ``RuntimeError`` and hide the real error.
    """
    with contextlib.ExitStack() as stack:
        try:
            stack.enter_context(_orig_no_sync(self))
        except AssertionError:
            # Best effort: fall through and run the body without no_sync.
            pass
        yield
|
||||
|
||||
# Install the tolerant wrapper on the class so every DeepSpeedEngine instance uses it.
deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
|
||||
|
||||
# Module-level logger; handlers and level are configured by logging.basicConfig() in main().
logger = logging.getLogger(__name__)
|
||||
|
||||
# Label value written for padding and for tokens that should not be trained on.
# -100 is the default ignore_index of torch.nn.CrossEntropyLoss.
IGNORE_INDEX = -100
|
||||
|
||||
|
||||
@dataclass
class ModelArguments:
    """Options that select the pretrained checkpoint and its weight dtype."""

    # No default: a value must always be supplied.
    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier"),
    )
    # Dtype name; main() translates it into a torch dtype.
    torch_dtype: Optional[str] = field(
        metadata=dict(help="torch dtype for model weights (float16, bfloat16, float32)"),
        default="bfloat16",
    )
|
||||
|
||||
|
||||
@dataclass
class DataArguments:
    """Options describing the SFT data source and how examples are encoded."""

    # --- data source -------------------------------------------------------
    # No default: a value must always be supplied.
    data_path: str = field(metadata=dict(help="Path to SFT data file or directory"))
    # Encoded examples are truncated to this many tokens during preprocessing.
    max_seq_length: int = field(
        metadata=dict(help="Maximum sequence length for training"),
        default=4096,
    )

    # --- column mapping (None lets preprocessing pick a known column name) ---
    prompt_column: Optional[str] = field(
        metadata=dict(help="Prompt/instruction column name. Auto-detected if omitted."),
        default=None,
    )
    input_column: Optional[str] = field(
        metadata=dict(help="Optional extra input/context column name"),
        default=None,
    )
    response_column: Optional[str] = field(
        metadata=dict(help="Response/output column name. Auto-detected if omitted."),
        default=None,
    )
    messages_column: Optional[str] = field(
        metadata=dict(help="Chat messages column name. Auto-detected if omitted."),
        default=None,
    )
    system_column: Optional[str] = field(
        metadata=dict(help="Optional system prompt column name"),
        default=None,
    )

    # --- encoding behaviour ------------------------------------------------
    train_on_prompt: bool = field(
        metadata=dict(help="Whether to compute loss on prompt/user tokens"),
        default=False,
    )
    add_eos_token: bool = field(
        metadata=dict(help="Append eos_token to plain prompt/response examples"),
        default=True,
    )
    preprocessing_num_workers: int = field(
        metadata=dict(help="Number of workers for data preprocessing"),
        default=8,
    )
|
||||
|
||||
|
||||
class SFTDataCollator:
    """Right-pad tokenized SFT examples into equally sized ``torch.long`` tensors.

    Each feature is a dict with ``input_ids`` and ``labels`` sequences (assumed
    to be the same length -- the padding amount is derived from ``input_ids``
    only). Padding positions get attention mask 0 and label ``IGNORE_INDEX``.

    Args:
        tokenizer: Object exposing ``pad_token_id`` (and optionally
            ``eos_token_id``) used to fill padded input positions.
        pad_to_multiple_of: When truthy, the batch length is rounded up to the
            next multiple of this value. ``None``/0 pads to the longest example.
    """

    def __init__(self, tokenizer, pad_to_multiple_of: Optional[int] = 8):
        self.tokenizer = tokenizer
        self.pad_to_multiple_of = pad_to_multiple_of

    def _resolve_pad_token_id(self) -> int:
        """Return the id used to fill padded input positions.

        Falls back from ``pad_token_id`` to ``eos_token_id`` and finally to 0.
        The exact value is not significant for the loss because padded
        positions carry attention mask 0 and label ``IGNORE_INDEX``; it only
        has to be a valid integer so the tensor can be built.
        """
        for attribute in ("pad_token_id", "eos_token_id"):
            token_id = getattr(self.tokenizer, attribute, None)
            if token_id is not None:
                return token_id
        return 0

    def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """Collate ``features`` into ``input_ids``, ``attention_mask`` and ``labels``.

        Raises:
            ValueError: If ``features`` is empty.
        """
        if not features:
            raise ValueError("SFTDataCollator received an empty batch")

        max_length = max(len(feature["input_ids"]) for feature in features)
        multiple = self.pad_to_multiple_of
        if multiple:
            # Round up to the next multiple.
            max_length = ((max_length + multiple - 1) // multiple) * multiple

        pad_token_id = self._resolve_pad_token_id()

        input_ids = []
        attention_mask = []
        labels = []
        for feature in features:
            length = len(feature["input_ids"])
            pad_length = max_length - length
            input_ids.append(list(feature["input_ids"]) + [pad_token_id] * pad_length)
            attention_mask.append([1] * length + [0] * pad_length)
            labels.append(list(feature["labels"]) + [IGNORE_INDEX] * pad_length)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
|
||||
|
||||
|
||||
# File extension -> name of the `datasets` builder that reads it. Note that
# plain-text files are read by the "text" builder; there is no "txt" builder.
_DATASET_BUILDERS = {
    "parquet": "parquet",
    "json": "json",
    "jsonl": "json",
    "csv": "csv",
    "txt": "text",
}


def _dataset_builder_for(path: str):
    """Return the builder name for ``path``'s extension, or None if unsupported."""
    extension = os.path.splitext(path)[1].lstrip(".").lower()
    return _DATASET_BUILDERS.get(extension)


def load_sft_dataset(data_path: str):
    """Load the ``train`` split from a single data file or a directory of files.

    Supported extensions (case-insensitive): ``parquet``, ``json``, ``jsonl``,
    ``csv`` and ``txt``.

    For a directory, entries are visited in sorted order so the outcome does
    not depend on filesystem listing order. The format of the first supported
    file decides which builder is used, and only regular files of that same
    format are loaded; other entries are ignored.

    Args:
        data_path: Path to a data file or to a directory containing data files.

    Returns:
        Whatever ``load_dataset(..., split="train")`` returns.

    Raises:
        ValueError: If the file extension is unsupported, the directory holds
            no supported file, or ``data_path`` does not exist.
    """
    if os.path.isfile(data_path):
        builder = _dataset_builder_for(data_path)
        if builder is None:
            extension = os.path.splitext(data_path)[1].lstrip(".").lower()
            raise ValueError(f"Unsupported data file extension: {extension}")
        return load_dataset(builder, data_files=data_path, split="train")

    if os.path.isdir(data_path):
        builder = None
        data_files = []
        for name in sorted(os.listdir(data_path)):
            candidate = os.path.join(data_path, name)
            current = _dataset_builder_for(name)
            # Skip unsupported names and sub-directories that merely look like data files.
            if current is None or not os.path.isfile(candidate):
                continue
            builder = builder or current
            if current == builder:
                data_files.append(candidate)
        if not data_files:
            raise ValueError(f"No supported data files found in: {data_path}")
        return load_dataset(builder, data_files=data_files, split="train")

    raise ValueError(f"Data path not found: {data_path}")
|
||||
|
||||
|
||||
def choose_column(
    column_names: List[str], explicit: Optional[str], candidates: List[str]
) -> Optional[str]:
    """Pick the dataset column to use for one role (prompt, response, ...).

    Args:
        column_names: Columns available in the dataset.
        explicit: Column requested by the user; a falsy value means "not set".
        candidates: Fallback names, tried in order when ``explicit`` is not set.

    Returns:
        ``explicit`` when it is set and present, otherwise the first candidate
        found in ``column_names``, otherwise ``None``.

    Raises:
        ValueError: If ``explicit`` is set but is not one of ``column_names``.
    """
    if not explicit:
        # Auto-detection: first known name that exists wins.
        return next((name for name in candidates if name in column_names), None)
    if explicit in column_names:
        return explicit
    raise ValueError(f"Column '{explicit}' not found. Available columns: {column_names}")
|
||||
|
||||
|
||||
def parse_messages(value: Any) -> List[Dict[str, str]]:
    """Normalise a chat transcript into ``{"role", "content"}`` dicts.

    ``value`` is either a list of message dicts or a JSON string encoding such
    a list. Each message may use the keys ``role``/``content`` or the
    alternative keys ``from``/``value``; the roles ``human`` and ``gpt`` are
    renamed to ``user`` and ``assistant``. Role and content are converted to
    ``str`` in the result.

    Raises:
        ValueError: If the (decoded) value is not a list, a message is not a
            dict, or a message lacks a role or a content.
    """
    decoded = json.loads(value) if isinstance(value, str) else value
    if not isinstance(decoded, list):
        raise ValueError("messages/conversations column must be a list or JSON string")

    role_aliases = (("human", "user"), ("gpt", "assistant"))
    normalized = []
    for entry in decoded:
        if not isinstance(entry, dict):
            raise ValueError("Each message must be a dict")

        # Prefer the primary key; fall back to the alternative key only when it is absent.
        role = entry["role"] if "role" in entry else entry.get("from")
        content = entry["content"] if "content" in entry else entry.get("value")

        for alias, canonical in role_aliases:
            if role == alias:
                role = canonical
                break

        if role is None or content is None:
            raise ValueError("Each message must contain role/from and content/value")
        normalized.append({"role": str(role), "content": str(content)})

    return normalized
|
||||
|
||||
|
||||
def tokenize_text(tokenizer, text: str) -> List[int]:
    """Return the ``input_ids`` for ``text``, tokenized with ``add_special_tokens=False``."""
    encoding = tokenizer(text, add_special_tokens=False)
    return encoding["input_ids"]
|
||||
|
||||
|
||||
def apply_chat_template(tokenizer, messages: List[Dict[str, str]], add_generation_prompt: bool) -> str:
    """Render ``messages`` to text with the tokenizer's chat template.

    The template is applied with ``tokenize=False`` and the given
    ``add_generation_prompt`` flag.

    Raises:
        ValueError: If the tokenizer has no ``chat_template``.
    """
    if tokenizer.chat_template is not None:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=add_generation_prompt,
        )
    raise ValueError(
        "The tokenizer has no chat_template. Use prompt/response columns or set a chat_template."
    )
|
||||
|
||||
|
||||
def encode_prompt_response(
    example: Dict[str, Any],
    tokenizer,
    data_args: DataArguments,
    prompt_column: str,
    input_column: Optional[str],
    response_column: str,
) -> Tuple[List[int], List[int]]:
    """Encode one prompt/response row into ``(input_ids, labels)``.

    The user text is the prompt column, followed by a newline and the input
    column when that column is configured and non-empty for this row.

    * With a chat template: the row is rendered as an optional system message
      (from ``data_args.system_column``), a user message and an assistant
      message. The prompt length is the token count of the conversation
      without the assistant message, rendered with a generation prompt.
    * Without a chat template: the text is ``prompt + "\\n" + response``, with
      the EOS token appended when ``data_args.add_eos_token`` is set and the
      tokenizer has one. The system column is not used in this mode.

    Unless ``data_args.train_on_prompt`` is set, the first ``prompt_length``
    labels are replaced by ``IGNORE_INDEX``.
    """
    user_text = str(example[prompt_column])
    extra_input = example.get(input_column) if input_column else None
    if extra_input:
        user_text = "\n".join((user_text, str(example[input_column])))
    answer_text = str(example[response_column])

    system_column = data_args.system_column
    conversation = [
        {"role": "user", "content": user_text},
        {"role": "assistant", "content": answer_text},
    ]
    if system_column and example.get(system_column):
        conversation.insert(0, {"role": "system", "content": str(example[system_column])})

    if tokenizer.chat_template is None:
        # Plain-text layout: "<prompt>\n<response>[eos]".
        if data_args.add_eos_token and tokenizer.eos_token:
            answer_text += tokenizer.eos_token
        prompt_prefix = user_text + "\n"
        input_ids = tokenize_text(tokenizer, prompt_prefix + answer_text)
        prompt_length = len(tokenize_text(tokenizer, prompt_prefix))
    else:
        rendered_full = apply_chat_template(tokenizer, conversation, add_generation_prompt=False)
        rendered_prompt = apply_chat_template(
            tokenizer, conversation[:-1], add_generation_prompt=True
        )
        input_ids = tokenize_text(tokenizer, rendered_full)
        prompt_length = len(tokenize_text(tokenizer, rendered_prompt))

    if data_args.train_on_prompt:
        return input_ids, list(input_ids)

    # Mask the prompt part; never mask more positions than there are tokens.
    masked = min(prompt_length, len(input_ids))
    labels = [IGNORE_INDEX] * masked + list(input_ids[masked:])
    return input_ids, labels
|
||||
|
||||
|
||||
def encode_messages(
    example: Dict[str, Any],
    tokenizer,
    data_args: DataArguments,
    messages_column: str,
) -> Tuple[List[int], List[int]]:
    """Encode one multi-turn chat row into ``(input_ids, labels)``.

    * With a chat template: the whole conversation is rendered and tokenized
      once. Unless ``data_args.train_on_prompt`` is set, labels start fully
      masked and each assistant turn is unmasked over the token span between
      "conversation before the turn, rendered with a generation prompt" and
      "conversation including the turn".
    * Without a chat template: each message is rendered as
      ``"<role>: <content>\\n"`` (plus EOS for assistant messages when
      ``data_args.add_eos_token`` is set and the tokenizer has one) and
      tokenized separately; non-assistant parts are masked unless
      ``data_args.train_on_prompt`` is set.
    """
    messages = parse_messages(example[messages_column])
    train_everything = data_args.train_on_prompt

    if tokenizer.chat_template is None:
        input_ids = []
        labels = []
        for message in messages:
            from_assistant = message["role"] == "assistant"
            part = f"{message['role']}: {message['content']}\n"
            if data_args.add_eos_token and from_assistant and tokenizer.eos_token:
                part += tokenizer.eos_token
            part_ids = tokenize_text(tokenizer, part)
            input_ids.extend(part_ids)
            if train_everything or from_assistant:
                labels.extend(part_ids)
            else:
                labels.extend([IGNORE_INDEX] * len(part_ids))
        return input_ids, labels

    rendered = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
    input_ids = tokenize_text(tokenizer, rendered)
    if train_everything:
        return input_ids, list(input_ids)

    labels = [IGNORE_INDEX] * len(input_ids)
    assistant_positions = [
        position for position, message in enumerate(messages) if message["role"] == "assistant"
    ]
    for position in assistant_positions:
        prefix_text = apply_chat_template(
            tokenizer, messages[:position], add_generation_prompt=True
        )
        through_text = apply_chat_template(
            tokenizer, messages[: position + 1], add_generation_prompt=False
        )
        start = len(tokenize_text(tokenizer, prefix_text))
        end = len(tokenize_text(tokenizer, through_text))
        # Unmask exactly the tokens contributed by this assistant turn.
        labels[start:end] = input_ids[start:end]

    return input_ids, labels
|
||||
|
||||
|
||||
def preprocess_sft_dataset(raw_dataset, tokenizer, data_args: DataArguments):
    """Tokenize ``raw_dataset`` into ``input_ids``/``attention_mask``/``labels`` rows.

    Columns are resolved with ``choose_column``. A messages column, when found,
    takes precedence; otherwise both a prompt column and a response column are
    required. Every encoded example is truncated to
    ``data_args.max_seq_length``; examples that end up empty or with every
    label equal to ``IGNORE_INDEX`` are dropped. All original columns are
    removed from the result.

    Raises:
        ValueError: If neither a messages column nor a prompt/response pair
            can be determined (or an explicitly requested column is missing).
    """
    column_names = raw_dataset.column_names
    messages_column = choose_column(
        column_names, data_args.messages_column, ["messages", "conversations"]
    )
    prompt_column = choose_column(
        column_names, data_args.prompt_column, ["prompt", "instruction", "question"]
    )
    input_column = choose_column(column_names, data_args.input_column, ["input", "context"])
    response_column = choose_column(
        column_names, data_args.response_column, ["response", "output", "answer", "chosen"]
    )

    if not messages_column and not (prompt_column and response_column):
        raise ValueError(
            "Cannot infer SFT data format. Provide either messages/conversations or "
            "prompt/instruction plus response/output columns."
        )
    if messages_column:
        logger.info(f"Using chat messages column: {messages_column}")
    else:
        logger.info(f"Using prompt column '{prompt_column}' and response column '{response_column}'")

    def encode_example(example):
        # Chat-style rows win over prompt/response rows when both are available.
        if messages_column:
            return encode_messages(example, tokenizer, data_args, messages_column)
        return encode_prompt_response(
            example, tokenizer, data_args, prompt_column, input_column, response_column
        )

    def encode_batch(examples):
        encoded = {"input_ids": [], "attention_mask": [], "labels": []}

        row_count = len(next(iter(examples.values())))
        for row in range(row_count):
            # Rebuild a per-row dict from the column-oriented batch.
            example = {name: values[row] for name, values in examples.items()}
            input_ids, labels = encode_example(example)

            input_ids = input_ids[: data_args.max_seq_length]
            labels = labels[: data_args.max_seq_length]
            # Keep only rows that still have at least one trainable token.
            if input_ids and any(label != IGNORE_INDEX for label in labels):
                encoded["input_ids"].append(input_ids)
                encoded["attention_mask"].append([1] * len(input_ids))
                encoded["labels"].append(labels)

        return encoded

    return raw_dataset.map(
        encode_batch,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        desc="Tokenizing SFT data",
    )
|
||||
|
||||
|
||||
def main():
    """Run supervised fine-tuning end to end.

    Steps: parse ``ModelArguments``/``DataArguments``/``TrainingArguments``
    from the command line, configure logging (INFO on local rank -1/0, WARN
    elsewhere), seed the RNGs, load tokenizer and model, load and tokenize the
    dataset, train with ``Trainer``, then save the model, trainer state and
    training metrics.

    Raises:
        ValueError: If no sample is left after preprocessing.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.info(f"Training args: {training_args}")

    set_seed(training_args.seed)

    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    requested_dtype = model_args.torch_dtype
    if requested_dtype is not None and requested_dtype not in dtype_map:
        # Keep the historical fallback, but make it visible instead of silent.
        logger.warning(
            f"Unsupported torch_dtype '{requested_dtype}'; falling back to bfloat16. "
            f"Supported values: {sorted(dtype_map)}"
        )
    torch_dtype = dtype_map.get(requested_dtype, torch.bfloat16)

    # NOTE: trust_remote_code=True (tokenizer and model below) lets the
    # checkpoint repository supply Python code that is executed locally; only
    # use checkpoints from trusted sources.
    logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        # The collator needs a padding token; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token

    logger.info(f"Loading model from {model_args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    model.config.use_cache = False

    logger.info(f"Loading SFT dataset from {data_args.data_path}")
    raw_dataset = load_sft_dataset(data_args.data_path)
    logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")

    train_dataset = preprocess_sft_dataset(raw_dataset, tokenizer, data_args)
    logger.info(f"Processed dataset: {len(train_dataset)} samples")
    if len(train_dataset) == 0:
        # Fail here with a clear message rather than deep inside the Trainer.
        raise ValueError(
            "No trainable samples left after preprocessing. Check the column mapping, "
            "max_seq_length and train_on_prompt settings."
        )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=SFTDataCollator(tokenizer),
    )

    logger.info("Starting SFT training...")
    train_result = trainer.train(
        resume_from_checkpoint=training_args.resume_from_checkpoint
    )

    trainer.save_model()
    trainer.save_state()

    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user