425 lines
14 KiB
Python
425 lines
14 KiB
Python
"""
|
|
Supervised fine-tuning script using DeepSpeed + HuggingFace Trainer.
|
|
"""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Dict, List, Optional, Tuple
|
|
|
|
import contextlib
|
|
|
|
import torch
|
|
from datasets import load_dataset
|
|
from transformers import (
|
|
AutoModelForCausalLM,
|
|
AutoTokenizer,
|
|
HfArgumentParser,
|
|
Trainer,
|
|
TrainingArguments,
|
|
set_seed,
|
|
)
|
|
|
|
import deepspeed
|
|
# Keep a handle on DeepSpeed's own implementation so the wrapper can delegate to it.
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync


@contextlib.contextmanager
def _patched_no_sync(self):
    """Enter DeepSpeed's original ``no_sync`` context if it allows it, else do nothing.

    The original ``DeepSpeedEngine.no_sync`` can raise ``AssertionError`` when it is
    entered in this setup. NOTE(review): presumably this is DeepSpeed's check that
    rejects ``no_sync`` when ZeRO partitions gradients; confirm against the
    installed DeepSpeed version. In that case the caller's block runs without the
    context.

    Only an ``AssertionError`` raised while *entering* the original context is
    suppressed. Exceptions raised by the caller's ``with`` body propagate unchanged
    and still run the original context's exit logic. Wrapping the ``yield`` in the
    ``try`` instead would catch the body's own AssertionErrors and yield a second
    time, which makes ``contextlib`` raise ``RuntimeError``.
    """
    with contextlib.ExitStack() as stack:
        try:
            stack.enter_context(_orig_no_sync(self))
        except AssertionError:
            # The original context refused to start; fall back to a no-op context.
            pass
        yield


deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
|
|
|
|
# Module logger; its level is configured in main() via logging.basicConfig.
logger = logging.getLogger(__name__)

# Label value used for positions that must not contribute to the loss. Prompt
# tokens, padding, and masked turns are all labelled with it. -100 is the
# default ignore_index of torch.nn.CrossEntropyLoss.
IGNORE_INDEX: int = -100
|
|
|
|
|
|
@dataclass
class ModelArguments:
    """Command-line options selecting the base model and its weight dtype."""

    # Local directory or hub identifier of the causal LM to fine-tune (required).
    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier"),
    )
    # Dtype name for the loaded weights; main() translates it to a torch dtype.
    torch_dtype: Optional[str] = field(
        default="bfloat16",
        metadata=dict(help="torch dtype for model weights (float16, bfloat16, float32)"),
    )
|
|
|
|
|
|
@dataclass
class DataArguments:
    """Command-line options describing the SFT data and how it is tokenized."""

    # File or directory to load; see load_sft_dataset for the accepted formats.
    data_path: str = field(metadata=dict(help="Path to SFT data file or directory"))
    # Examples are truncated to this many tokens during preprocessing.
    max_seq_length: int = field(
        default=4096,
        metadata=dict(help="Maximum sequence length for training"),
    )
    # Column overrides; None means "auto-detect from the dataset's column names".
    prompt_column: Optional[str] = field(
        default=None,
        metadata=dict(help="Prompt/instruction column name. Auto-detected if omitted."),
    )
    input_column: Optional[str] = field(
        default=None,
        metadata=dict(help="Optional extra input/context column name"),
    )
    response_column: Optional[str] = field(
        default=None,
        metadata=dict(help="Response/output column name. Auto-detected if omitted."),
    )
    messages_column: Optional[str] = field(
        default=None,
        metadata=dict(help="Chat messages column name. Auto-detected if omitted."),
    )
    system_column: Optional[str] = field(
        default=None,
        metadata=dict(help="Optional system prompt column name"),
    )
    # When False, prompt/user tokens get IGNORE_INDEX labels.
    train_on_prompt: bool = field(
        default=False,
        metadata=dict(help="Whether to compute loss on prompt/user tokens"),
    )
    # Only used on the no-chat-template code paths.
    add_eos_token: bool = field(
        default=True,
        metadata=dict(help="Append eos_token to plain prompt/response examples"),
    )
    # Passed to datasets.Dataset.map as num_proc.
    preprocessing_num_workers: int = field(
        default=8,
        metadata=dict(help="Number of workers for data preprocessing"),
    )
|
|
|
|
|
|
class SFTDataCollator:
    """Right-pad tokenized SFT examples into a batch of ``torch.long`` tensors.

    Each feature is a dict with ``input_ids`` and ``labels`` lists of equal length.
    Sequences are padded on the right to the longest length in the batch. That
    length is rounded up to a multiple of ``pad_to_multiple_of`` when the value is
    truthy. Padding positions are filled as follows:

    * ``input_ids``: the tokenizer's ``pad_token_id``;
    * ``attention_mask``: 0;
    * ``labels``: ``label_pad_token_id``, which defaults to ``IGNORE_INDEX``.
    """

    def __init__(
        self,
        tokenizer,
        pad_to_multiple_of: Optional[int] = 8,
        label_pad_token_id: Optional[int] = None,
    ):
        self.tokenizer = tokenizer
        self.pad_to_multiple_of = pad_to_multiple_of
        # None keeps the original behaviour: padded label positions use IGNORE_INDEX.
        self.label_pad_token_id = (
            IGNORE_INDEX if label_pad_token_id is None else label_pad_token_id
        )

    def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """Collate *features* into ``input_ids``/``attention_mask``/``labels`` tensors.

        Raises:
            ValueError: if the tokenizer has no ``pad_token_id``. This is checked up
                front so the failure does not depend on whether a given batch happens
                to need padding.
        """
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            raise ValueError(
                "The tokenizer has no pad_token_id; set tokenizer.pad_token before "
                "collating SFT batches."
            )

        max_length = max(len(feature["input_ids"]) for feature in features)
        if self.pad_to_multiple_of:
            multiple = self.pad_to_multiple_of
            # Round up to the next multiple (e.g. for tensor-core friendly shapes).
            max_length = ((max_length + multiple - 1) // multiple) * multiple

        input_ids = []
        attention_mask = []
        labels = []
        for feature in features:
            length = len(feature["input_ids"])
            pad_length = max_length - length
            input_ids.append(feature["input_ids"] + [pad_token_id] * pad_length)
            attention_mask.append([1] * length + [0] * pad_length)
            labels.append(feature["labels"] + [self.label_pad_token_id] * pad_length)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
|
|
|
|
|
|
def load_sft_dataset(data_path: str):
|
|
if os.path.isfile(data_path):
|
|
extension = os.path.splitext(data_path)[1].lstrip(".").lower()
|
|
if extension == "jsonl":
|
|
extension = "json"
|
|
if extension not in {"parquet", "json", "csv", "txt"}:
|
|
raise ValueError(f"Unsupported data file extension: {extension}")
|
|
return load_dataset(extension, data_files=data_path, split="train")
|
|
|
|
if os.path.isdir(data_path):
|
|
data_files = []
|
|
extension = None
|
|
for name in os.listdir(data_path):
|
|
current_extension = os.path.splitext(name)[1].lstrip(".").lower()
|
|
if current_extension == "jsonl":
|
|
current_extension = "json"
|
|
if current_extension in {"parquet", "json", "csv", "txt"}:
|
|
extension = extension or current_extension
|
|
if current_extension == extension:
|
|
data_files.append(os.path.join(data_path, name))
|
|
if not data_files or extension is None:
|
|
raise ValueError(f"No supported data files found in: {data_path}")
|
|
return load_dataset(extension, data_files=sorted(data_files), split="train")
|
|
|
|
raise ValueError(f"Data path not found: {data_path}")
|
|
|
|
|
|
def choose_column(
    column_names: List[str], explicit: Optional[str], candidates: List[str]
) -> Optional[str]:
    """Pick the dataset column to use for one role.

    A non-empty *explicit* name must be present in *column_names* and is returned
    unchanged. Otherwise the first entry of *candidates* found in *column_names*
    is returned, or None when none of them is present.

    Raises:
        ValueError: if *explicit* is given but is not one of *column_names*.
    """
    if not explicit:
        return next((name for name in candidates if name in column_names), None)
    if explicit in column_names:
        return explicit
    raise ValueError(f"Column '{explicit}' not found. Available columns: {column_names}")
|
|
|
|
|
|
def parse_messages(value: Any) -> List[Dict[str, str]]:
    """Normalize a chat transcript into ``[{"role": ..., "content": ...}, ...]``.

    *value* is a list of dicts or a JSON string encoding one. Each item may use the
    ``role``/``content`` keys or the ``from``/``value`` keys. A present ``role`` or
    ``content`` key takes precedence even if its value is None. The speaker names
    ``human`` and ``gpt`` become ``user`` and ``assistant``. Roles and contents are
    converted with ``str``.

    Raises:
        ValueError: if the decoded value is not a list, an item is not a dict, or
            an item has no role or no content. Malformed JSON raises
            ``json.JSONDecodeError``, a ``ValueError`` subclass.
    """
    turns = json.loads(value) if isinstance(value, str) else value
    if not isinstance(turns, list):
        raise ValueError("messages/conversations column must be a list or JSON string")

    def _normalize(item: Any) -> Dict[str, str]:
        if not isinstance(item, dict):
            raise ValueError("Each message must be a dict")
        role = item["role"] if "role" in item else item.get("from")
        content = item["content"] if "content" in item else item.get("value")
        role = "user" if role == "human" else "assistant" if role == "gpt" else role
        if role is None or content is None:
            raise ValueError("Each message must contain role/from and content/value")
        return {"role": str(role), "content": str(content)}

    return [_normalize(item) for item in turns]
|
|
|
|
|
|
def tokenize_text(tokenizer, text: str) -> List[int]:
    """Return the token ids of *text*, without the tokenizer adding special tokens.

    The callers in this file tokenize text that is already fully rendered (chat
    templates or explicitly appended eos strings), and they compare prefix lengths
    measured with this function.
    """
    encoding = tokenizer(text, add_special_tokens=False)
    return encoding["input_ids"]
|
|
|
|
|
|
def apply_chat_template(tokenizer, messages: List[Dict[str, str]], add_generation_prompt: bool) -> str:
    """Render *messages* to text with the tokenizer's chat template, without tokenizing.

    Raises:
        ValueError: if the tokenizer has no ``chat_template``.
    """
    if tokenizer.chat_template is not None:
        return tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=add_generation_prompt
        )
    raise ValueError(
        "The tokenizer has no chat_template. Use prompt/response columns or set a chat_template."
    )
|
|
|
|
|
|
def encode_prompt_response(
    example: Dict[str, Any],
    tokenizer,
    data_args: DataArguments,
    prompt_column: str,
    input_column: Optional[str],
    response_column: str,
) -> Tuple[List[int], List[int]]:
    """Tokenize one prompt/response example and build its training labels.

    The prompt is the ``prompt_column`` value. When ``input_column`` is set and its
    value is truthy, ``"\\n"`` and that value are appended to the prompt. An
    optional system prompt is read from ``data_args.system_column``.

    * With a chat template, the example is rendered as a system/user/assistant
      conversation. The prompt part is the conversation without the assistant
      turn, rendered with the generation prompt.
    * Without a template, the text is ``[system + "\\n"] + prompt + "\\n" + response``.
      The tokenizer's eos token is appended when ``data_args.add_eos_token`` is set
      and the tokenizer has one. The system line belongs to the prompt part.

    Unless ``data_args.train_on_prompt`` is set, labels of the prompt-part tokens are
    ``IGNORE_INDEX``.

    Returns:
        ``(input_ids, labels)`` of equal length.

    NOTE(review): the prompt length is measured by tokenizing the prompt text on its
    own. This assumes those tokens form a prefix of the full text's tokens; a
    tokenizer that merges across the boundary would shift the mask by a token.
    """
    prompt = str(example[prompt_column])
    if input_column and example.get(input_column):
        prompt = prompt + "\n" + str(example[input_column])
    response = str(example[response_column])

    system: Optional[str] = None
    if data_args.system_column and example.get(data_args.system_column):
        system = str(example[data_args.system_column])

    if tokenizer.chat_template is not None:
        messages = []
        if system is not None:
            messages.append({"role": "system", "content": system})
        messages.append({"role": "user", "content": prompt})
        messages.append({"role": "assistant", "content": response})
        full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
        prompt_text = apply_chat_template(tokenizer, messages[:-1], add_generation_prompt=True)
    else:
        prompt_text = prompt + "\n"
        if system is not None:
            # Keep --system_column effective without a chat template: prepend the
            # system prompt as its own (masked) line instead of dropping it.
            prompt_text = system + "\n" + prompt_text
        response_text = response
        if data_args.add_eos_token and tokenizer.eos_token:
            response_text += tokenizer.eos_token
        full_text = prompt_text + response_text

    input_ids = tokenize_text(tokenizer, full_text)
    prompt_length = len(tokenize_text(tokenizer, prompt_text))

    labels = input_ids.copy()
    if not data_args.train_on_prompt:
        masked = min(prompt_length, len(labels))
        labels[:masked] = [IGNORE_INDEX] * masked
    return input_ids, labels
|
|
|
|
|
|
def encode_messages(
    example: Dict[str, Any],
    tokenizer,
    data_args: DataArguments,
    messages_column: str,
) -> Tuple[List[int], List[int]]:
    """Tokenize one chat example from *messages_column* and build its labels.

    With a chat template, the whole conversation is rendered and tokenized.

    * If ``data_args.train_on_prompt`` is set, every token is a label.
    * Otherwise only assistant turns are labelled. A turn's span runs from the
      token count of the conversation before it (rendered with the generation
      prompt) to the token count of the conversation through it.

    Without a template, each message is rendered as ``"role: content\\n"`` and
    tokenized separately.

    * Assistant messages get the eos token appended when
      ``data_args.add_eos_token`` is set and the tokenizer has one.
    * Assistant chunks, or all chunks when ``train_on_prompt`` is set, are labelled.
    * All other chunks get ``IGNORE_INDEX``.

    Returns:
        ``(input_ids, labels)`` of equal length.

    NOTE(review): the template path assumes each rendered prefix tokenizes to a
    prefix of the full conversation. Templates that render earlier turns
    differently from the final one would misalign the labelled spans.
    """
    conversation = parse_messages(example[messages_column])

    if tokenizer.chat_template is None:
        # Plain-text fallback: one "role: content" chunk per message.
        token_ids: List[int] = []
        targets: List[int] = []
        for turn in conversation:
            is_assistant = turn["role"] == "assistant"
            chunk = f"{turn['role']}: {turn['content']}\n"
            if data_args.add_eos_token and is_assistant and tokenizer.eos_token:
                chunk += tokenizer.eos_token
            chunk_ids = tokenize_text(tokenizer, chunk)
            token_ids.extend(chunk_ids)
            supervised = data_args.train_on_prompt or is_assistant
            targets.extend(chunk_ids if supervised else [IGNORE_INDEX] * len(chunk_ids))
        return token_ids, targets

    rendered = apply_chat_template(tokenizer, conversation, add_generation_prompt=False)
    token_ids = tokenize_text(tokenizer, rendered)
    if data_args.train_on_prompt:
        return token_ids, token_ids.copy()

    targets = [IGNORE_INDEX] * len(token_ids)
    for position, turn in enumerate(conversation):
        if turn["role"] == "assistant":
            prefix = apply_chat_template(
                tokenizer, conversation[:position], add_generation_prompt=True
            )
            through = apply_chat_template(
                tokenizer, conversation[: position + 1], add_generation_prompt=False
            )
            lo = len(tokenize_text(tokenizer, prefix))
            hi = len(tokenize_text(tokenizer, through))
            # Unmask only this assistant turn's tokens.
            targets[lo:hi] = token_ids[lo:hi]
    return token_ids, targets
|
|
|
|
|
|
def preprocess_sft_dataset(raw_dataset, tokenizer, data_args: DataArguments):
    """Tokenize a raw SFT dataset into ``input_ids``/``attention_mask``/``labels`` rows.

    The data format is resolved from the dataset's columns:

    * chat format: a messages column, either ``--messages_column`` or an
      auto-detected ``messages``/``conversations`` column. Encoded by
      ``encode_messages``.
    * prompt/response format: a prompt column (``prompt``/``instruction``/
      ``question``) plus a response column (``response``/``output``/``answer``/
      ``chosen``), optionally with an input column (``input``/``context``).
      Encoded by ``encode_prompt_response``.

    Explicit arguments win over auto-detection. An auto-detected messages column
    is ignored when ``--prompt_column`` or ``--response_column`` is given without
    ``--messages_column``.

    Each example is truncated to ``data_args.max_seq_length`` tokens. Examples that
    are empty, or have no label other than ``IGNORE_INDEX`` after truncation, are
    dropped. All original columns are removed.

    Raises:
        ValueError: if no format can be resolved, or (via ``choose_column``) if an
            explicitly named column does not exist.
    """
    column_names = raw_dataset.column_names
    messages_column = choose_column(
        column_names, data_args.messages_column, ["messages", "conversations"]
    )
    prompt_column = choose_column(
        column_names, data_args.prompt_column, ["prompt", "instruction", "question"]
    )
    input_column = choose_column(
        column_names, data_args.input_column, ["input", "context"]
    )
    response_column = choose_column(
        column_names, data_args.response_column, ["response", "output", "answer", "chosen"]
    )

    prompt_format_requested = bool(data_args.prompt_column or data_args.response_column)
    if messages_column and prompt_format_requested and not data_args.messages_column:
        logger.info(
            f"Ignoring auto-detected messages column '{messages_column}' because "
            "prompt/response columns were requested explicitly"
        )
        messages_column = None

    if messages_column:
        logger.info(f"Using chat messages column: {messages_column}")
    elif prompt_column and response_column:
        logger.info(f"Using prompt column '{prompt_column}' and response column '{response_column}'")
    else:
        raise ValueError(
            "Cannot infer SFT data format. Provide either messages/conversations or "
            "prompt/instruction plus response/output columns."
        )

    def encode_batch(examples):
        """Encode a batched slice (column name -> list of values), dropping unusable rows."""
        batch_input_ids = []
        batch_labels = []
        batch_attention_mask = []

        batch_size = len(next(iter(examples.values())))
        for i in range(batch_size):
            example = {name: values[i] for name, values in examples.items()}
            if messages_column:
                input_ids, labels = encode_messages(example, tokenizer, data_args, messages_column)
            else:
                input_ids, labels = encode_prompt_response(
                    example, tokenizer, data_args, prompt_column, input_column, response_column
                )

            input_ids = input_ids[: data_args.max_seq_length]
            labels = labels[: data_args.max_seq_length]
            # Nothing left to learn from (e.g. truncation cut every response token).
            if not input_ids or all(label == IGNORE_INDEX for label in labels):
                continue

            batch_input_ids.append(input_ids)
            batch_labels.append(labels)
            batch_attention_mask.append([1] * len(input_ids))

        return {
            "input_ids": batch_input_ids,
            "attention_mask": batch_attention_mask,
            "labels": batch_labels,
        }

    # datasets rejects num_proc <= 0; None means "no multiprocessing", same as 1.
    num_workers = data_args.preprocessing_num_workers
    num_proc = num_workers if num_workers and num_workers > 1 else None

    return raw_dataset.map(
        encode_batch,
        batched=True,
        num_proc=num_proc,
        remove_columns=column_names,
        desc="Tokenizing SFT data",
    )
|
|
|
|
|
|
def main():
|
|
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
|
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
|
|
|
logging.basicConfig(
|
|
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
|
datefmt="%Y-%m-%d %H:%M:%S",
|
|
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
|
)
|
|
logger.info(f"Training args: {training_args}")
|
|
|
|
set_seed(training_args.seed)
|
|
|
|
dtype_map = {
|
|
"float16": torch.float16,
|
|
"bfloat16": torch.bfloat16,
|
|
"float32": torch.float32,
|
|
}
|
|
torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)
|
|
|
|
logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
|
|
tokenizer = AutoTokenizer.from_pretrained(
|
|
model_args.model_name_or_path,
|
|
trust_remote_code=True,
|
|
)
|
|
if tokenizer.pad_token is None:
|
|
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
logger.info(f"Loading model from {model_args.model_name_or_path}")
|
|
model = AutoModelForCausalLM.from_pretrained(
|
|
model_args.model_name_or_path,
|
|
torch_dtype=torch_dtype,
|
|
trust_remote_code=True,
|
|
attn_implementation="sdpa",
|
|
)
|
|
model.config.use_cache = False
|
|
|
|
logger.info(f"Loading SFT dataset from {data_args.data_path}")
|
|
raw_dataset = load_sft_dataset(data_args.data_path)
|
|
logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
|
|
|
|
train_dataset = preprocess_sft_dataset(raw_dataset, tokenizer, data_args)
|
|
logger.info(f"Processed dataset: {len(train_dataset)} samples")
|
|
|
|
trainer = Trainer(
|
|
model=model,
|
|
args=training_args,
|
|
train_dataset=train_dataset,
|
|
data_collator=SFTDataCollator(tokenizer),
|
|
)
|
|
|
|
logger.info("Starting SFT training...")
|
|
train_result = trainer.train(
|
|
resume_from_checkpoint=training_args.resume_from_checkpoint
|
|
)
|
|
|
|
trainer.save_model()
|
|
trainer.save_state()
|
|
|
|
metrics = train_result.metrics
|
|
metrics["train_samples"] = len(train_dataset)
|
|
trainer.log_metrics("train", metrics)
|
|
trainer.save_metrics("train", metrics)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main()
|