gpt-neox-japanese-1.4b/notebooks/LoRA.ipynb
ModelHub XC b737a3bd5d Initialize project; model provided by the ModelHub XC community
Model: stockmark/gpt-neox-japanese-1.4b
Source: Original Platform
2026-06-08 22:03:15 +08:00

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"source": [
"This notebook is an example of LoRA tuning the `stockmark/gpt-neox-japanese-1.4b` model on the `kunishou/databricks-dolly-15k-ja` dataset. The example below trains for 1 epoch, which takes about 30 minutes on a T4 GPU.\n",
"\n",
"- Model: https://huggingface.co/stockmark/gpt-neox-japanese-1.4b\n",
"- Data: https://github.com/kunishou/databricks-dolly-15k-ja\n",
"\n",
"\n",
"The settings used here are provisional; adjust them as needed."
],
"metadata": {
"id": "BPGgCZtMdMsv"
}
},
{
"cell_type": "markdown",
"source": [
"# Installing libraries"
],
"metadata": {
"id": "hCZH9e6EcZyj"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cmn52bx3v5Ha"
},
"outputs": [],
"source": [
"!python3 -m pip install -U pip\n",
"!python3 -m pip install transformers accelerate datasets peft"
]
},
{
"cell_type": "markdown",
"source": [
"# Preparation"
],
"metadata": {
"id": "4t3Cqs9_ce3J"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"import datasets\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments\n",
"from peft import get_peft_model, LoraConfig, TaskType, PeftModel, PeftConfig\n",
"\n",
"model_name = \"stockmark/gpt-neox-japanese-1.4b\"\n",
"peft_model_name = \"peft_model\"\n",
"\n",
"prompt_template = \"\"\"### Instruction:\n",
"{instruction}\n",
"\n",
"### Input:\n",
"{input}\n",
"\n",
"### Response:\n",
"\"\"\"\n",
"\n",
"def encode(sample):\n",
"    # Build the prompt and append EOS to the target so the model learns where to stop\n",
"    prompt = prompt_template.format(instruction=sample[\"instruction\"], input=sample[\"input\"])\n",
"    target = sample[\"output\"] + tokenizer.eos_token\n",
"    input_ids_prompt, input_ids_target = tokenizer([prompt, target]).input_ids\n",
"    input_ids = input_ids_prompt + input_ids_target\n",
"    # Mask the prompt tokens with -100 so the loss covers only the response\n",
"    labels = input_ids.copy()\n",
"    labels[:len(input_ids_prompt)] = [-100] * len(input_ids_prompt)\n",
"    return {\"input_ids\": input_ids, \"labels\": labels}\n",
"\n",
"def get_collator(tokenizer, max_length):\n",
"    def collator(batch):\n",
"        # Truncate every field of every sample to max_length tokens\n",
"        batch = [{key: value[:max_length] for key, value in sample.items()} for sample in batch]\n",
"        batch = tokenizer.pad(batch, padding=True)\n",
"        # Extend labels with -100 up to the padded input length\n",
"        batch[\"labels\"] = [e + [-100] * (len(batch[\"input_ids\"][0]) - len(e)) for e in batch[\"labels\"]]\n",
"        batch = {key: torch.tensor(value) for key, value in batch.items()}\n",
"        return batch\n",
"\n",
"    return collator\n"
],
"metadata": {
"id": "hNdYMGMRzAVn"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Preparing the dataset and model\n"
],
"metadata": {
"id": "UqXxPjJ_cliu"
}
},
{
"cell_type": "code",
"source": [
"# prepare dataset\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"dataset_name = \"kunishou/databricks-dolly-15k-ja\"\n",
"dataset = datasets.load_dataset(dataset_name)\n",
"dataset = dataset.map(encode)\n",
"dataset = dataset[\"train\"].train_test_split(0.2)\n",
"train_dataset = dataset[\"train\"]\n",
"val_dataset = dataset[\"test\"]\n",
"\n",
"# load model\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, device_map={\"\": 0}, torch_dtype=torch.float16)\n",
"\n",
"peft_config = LoraConfig(\n",
"    task_type=TaskType.CAUSAL_LM,\n",
"    inference_mode=False,\n",
"    target_modules=[\"query_key_value\"],\n",
"    r=16,\n",
"    lora_alpha=32,\n",
"    lora_dropout=0.05\n",
")\n",
"\n",
"model = get_peft_model(model, peft_config)\n",
"model.print_trainable_parameters()"
],
"metadata": {
"id": "ZWdN-p7t0Grk"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# LoRA tuning"
],
"metadata": {
"id": "XCrdVAJYc88c"
}
},
{
"cell_type": "code",
"source": [
"training_args = TrainingArguments(\n",
"    output_dir=\"./train_results\",\n",
"    learning_rate=2e-4,\n",
"    # Effective batch size per device: 4 * 4 = 16\n",
"    per_device_train_batch_size=4,\n",
"    gradient_accumulation_steps=4,\n",
"    per_device_eval_batch_size=16,\n",
"    num_train_epochs=1,\n",
"    logging_strategy='steps',\n",
"    logging_steps=10,\n",
"    save_strategy='epoch',\n",
"    # Deprecated since transformers 4.41 in favor of `eval_strategy`; rename it if your version rejects this argument\n",
"    evaluation_strategy='epoch',\n",
"    load_best_model_at_end=True,\n",
"    metric_for_best_model=\"eval_loss\",\n",
"    greater_is_better=False,\n",
"    save_total_limit=2\n",
")\n",
"\n",
"trainer = Trainer(\n",
"    model=model,\n",
"    args=training_args,\n",
"    train_dataset=train_dataset,\n",
"    eval_dataset=val_dataset,\n",
"    data_collator=get_collator(tokenizer, 512)\n",
")\n",
"\n",
"trainer.train()\n",
"model = trainer.model\n",
"model.save_pretrained(peft_model_name)"
],
"metadata": {
"id": "4LH9tOCTJVk1"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Loading the trained model"
],
"metadata": {
"id": "ORgzOPAqdEZR"
}
},
{
"cell_type": "code",
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, device_map={\"\": 0}, torch_dtype=torch.float16)\n",
"model = PeftModel.from_pretrained(model, peft_model_name)"
],
"metadata": {
"id": "yrExyO9EOvzR"
},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Inference"
],
"metadata": {
"id": "-dttR6tkdG0k"
}
},
{
"cell_type": "code",
"source": [
"prompt = prompt_template.format(instruction=\"日本で人気のスポーツは?\", input=\"\")\n",
"\n",
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
"with torch.no_grad():\n",
"    tokens = model.generate(\n",
"        **inputs,\n",
"        max_new_tokens=128,\n",
"        repetition_penalty=1.1\n",
"    )\n",
"\n",
"output = tokenizer.decode(tokens[0], skip_special_tokens=True)\n",
"print(output)"
],
"metadata": {
"id": "pC5t9F1GJuFN"
},
"execution_count": null,
"outputs": []
}
]
}