stockmark-13b/notebooks/LoRA.ipynb
ModelHub XC 26a257b71d Initialize project; model provided by the ModelHub XC community
Model: stockmark/stockmark-13b
Source: Original Platform
2026-05-19 12:22:16 +08:00

{
"cells": [
{
"cell_type": "markdown",
"id": "1a884871-7a65-4501-9063-c85ad260d0da",
"metadata": {},
"source": [
"このnotebookはstockmark/stockmark-13bのモデルをkunishou/databricks-dolly-15k-jaのデータセットを用いてLoRA tuningするためのコードの例です。A100またはH100のGPUを用いることを想定しています。\n",
"\n",
"- モデルhttps://huggingface.co/stockmark/stockmark-13b\n",
"- データhttps://github.com/kunishou/databricks-dolly-15k-ja\n",
"\n",
"以下の例では、学習を1 epochを行います。A100 GPUで実行すると30分ほどかかります。\n",
"\n",
"また、ここで用いられているハイパーパラメータは最適化されたものではありませんので、必要に応じて調整してください。"
]
},
{
"cell_type": "markdown",
"id": "93b3f4b5-2825-4ef3-a0ee-7a60155aee5d",
"metadata": {},
"source": [
"# 準備"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6a694ba9-a0fa-4f14-81cf-f35f683ba889",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import datasets\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments\n",
"from peft import get_peft_model, LoraConfig, PeftModel, PeftConfig\n",
"\n",
"model_name = \"stockmark/stockmark-13b\"\n",
"peft_model_name = \"stockmark-13b-adapter\"\n",
"\n",
"prompt_template = \"\"\"### Instruction:\n",
"{instruction}\n",
"\n",
"### Input:\n",
"{input}\n",
"\n",
"### Response:\n",
"\"\"\"\n",
"\n",
"def encode(sample):\n",
" prompt = prompt_template.format(instruction=sample[\"instruction\"], input=sample[\"input\"])\n",
" target = sample[\"output\"]\n",
" input_ids_prompt, input_ids_target = tokenizer([prompt, target], add_special_tokens=False).input_ids\n",
" input_ids_prompt = [ tokenizer.bos_token_id ] + input_ids_prompt\n",
" input_ids_target = input_ids_target + [ tokenizer.eos_token_id ]\n",
" input_ids = input_ids_prompt + input_ids_target\n",
" labels = input_ids.copy()\n",
" labels[:len(input_ids_prompt)] = [-100] * len(input_ids_prompt) # ignore label tokens in a prompt for loss calculation\n",
" return {\"input_ids\": input_ids, \"labels\": labels}\n",
"\n",
"def get_collator(tokenizer, max_length):\n",
" def collator(batch):\n",
" batch = [{ key: value[:max_length] for key, value in sample.items() } for sample in batch ]\n",
" batch = tokenizer.pad(batch)\n",
" batch[\"labels\"] = [ e + [-100] * (len(batch[\"input_ids\"][0]) - len(e)) for e in batch[\"labels\"] ]\n",
" batch = { key: torch.tensor(value) for key, value in batch.items() }\n",
" return batch\n",
"\n",
" return collator"
]
},
{
"cell_type": "markdown",
"id": "51e6cfcf-1ac1-400e-a4bc-ea64375d0f9e",
"metadata": {},
"source": [
"# データセットとモデルのロード"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3ac80067-4e60-46c4-90da-05647cf96ccd",
"metadata": {},
"outputs": [],
"source": [
"# load_tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"\n",
"# prepare dataset\n",
"dataset_name = \"kunishou/databricks-dolly-15k-ja\"\n",
"dataset = datasets.load_dataset(dataset_name)\n",
"dataset = dataset.map(encode)\n",
"dataset = dataset[\"train\"].train_test_split(0.1)\n",
"train_dataset = dataset[\"train\"]\n",
"val_dataset = dataset[\"test\"]\n",
"\n",
"# load model\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\", torch_dtype=torch.bfloat16)\n",
"\n",
"peft_config = LoraConfig(\n",
" task_type=\"CAUSAL_LM\",\n",
" inference_mode=False,\n",
" target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\"],\n",
" r=16,\n",
" lora_alpha=32,\n",
" lora_dropout=0.05\n",
")\n",
"\n",
"model = get_peft_model(model, peft_config)\n",
"model.print_trainable_parameters()"
]
},
{
"cell_type": "markdown",
"id": "9b471da0-7fba-4127-8b07-22da4cbee6a9",
"metadata": {},
"source": [
"# LoRA Tuning"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b9bafa12-538c-4abb-b8b3-bffeb0990b46",
"metadata": {},
"outputs": [],
"source": [
"training_args = TrainingArguments(\n",
" output_dir=\"./log_stockmark_13b\",\n",
" learning_rate=2e-4,\n",
" per_device_train_batch_size=2,\n",
" gradient_accumulation_steps=8,\n",
" per_device_eval_batch_size=16,\n",
" num_train_epochs=1,\n",
" logging_strategy='steps',\n",
" logging_steps=10,\n",
" save_strategy='epoch',\n",
" evaluation_strategy='epoch',\n",
" load_best_model_at_end=True,\n",
" metric_for_best_model=\"eval_loss\",\n",
" greater_is_better=False,\n",
" save_total_limit=2\n",
")\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=train_dataset,\n",
" eval_dataset=val_dataset,\n",
" data_collator=get_collator(tokenizer, 320)\n",
")\n",
"\n",
"# LoRA tuning\n",
"trainer.train()\n",
"\n",
"# save model\n",
"model = trainer.model\n",
"model.save_pretrained(peft_model_name)"
]
},
{
"cell_type": "markdown",
"id": "a3f80a8e-1ac2-4bdc-8232-fe0ee18ffff5",
"metadata": {},
"source": [
"# 学習したモデルのロードOptional\n",
"異なるセッションでモデルを読み込む場合、まず最初の準備のセクションのコードを実行して、このコードを実行してください。"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43241395-3035-4cb9-8c1c-45ffe8cd48be",
"metadata": {},
"outputs": [],
"source": [
"tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
"model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\", torch_dtype=torch.bfloat16)\n",
"model = PeftModel.from_pretrained(model, peft_model_name)"
]
},
{
"cell_type": "markdown",
"id": "2ce4db1f-9bad-4c8e-9c04-d1102b299f24",
"metadata": {},
"source": [
"# 推論"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d7d6359b-e0ac-49df-a178-39bb9f79ca93",
"metadata": {},
"outputs": [],
"source": [
"prompt = prompt_template.format(instruction=\"自然言語処理とは?\", input=\"\")\n",
"\n",
"inputs = tokenizer(prompt, return_tensors=\"pt\").to(model.device)\n",
"with torch.no_grad():\n",
" tokens = model.generate(\n",
" **inputs,\n",
" max_new_tokens=128,\n",
" do_sample=True,\n",
" temperature=0.7\n",
" )\n",
"\n",
"output = tokenizer.decode(tokens[0], skip_special_tokens=True)\n",
"print(output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
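
The prompt masking in `encode` and the label padding in the collator can be checked without downloading the model. The sketch below re-implements that logic on hand-written token ids. The helper names `mask_prompt_labels` and `pad_batch`, the pad id `0`, the BOS/EOS ids, and the toy token ids are all hypothetical and not part of the notebook. It also assumes right padding, whereas the notebook delegates padding to `tokenizer.pad`.

```python
IGNORE_INDEX = -100  # label value skipped by PyTorch's cross-entropy loss by default


def mask_prompt_labels(prompt_ids, target_ids, bos_id, eos_id):
    """Build input_ids and labels in the same way as the notebook's encode()."""
    prompt_ids = [bos_id] + prompt_ids
    target_ids = target_ids + [eos_id]
    input_ids = prompt_ids + target_ids
    # Prompt positions (including BOS) do not contribute to the loss.
    labels = [IGNORE_INDEX] * len(prompt_ids) + target_ids
    return {"input_ids": input_ids, "labels": labels}


def pad_batch(batch, max_length, pad_id):
    """Truncate to max_length, then right-pad each sample to the longest one."""
    batch = [{k: v[:max_length] for k, v in s.items()} for s in batch]
    width = max(len(s["input_ids"]) for s in batch)
    padded = {"input_ids": [], "attention_mask": [], "labels": []}
    for s in batch:
        gap = width - len(s["input_ids"])
        padded["input_ids"].append(s["input_ids"] + [pad_id] * gap)
        padded["attention_mask"].append([1] * len(s["input_ids"]) + [0] * gap)
        # Padding positions are ignored by the loss as well.
        padded["labels"].append(s["labels"] + [IGNORE_INDEX] * gap)
    return padded


# Toy ids (hypothetical): BOS=1, EOS=2, PAD=0.
samples = [
    mask_prompt_labels([11, 12, 13], [21, 22], bos_id=1, eos_id=2),
    mask_prompt_labels([14], [23], bos_id=1, eos_id=2),
]
batch = pad_batch(samples, max_length=320, pad_id=0)
print(batch["labels"])
```

Only the response tokens and the trailing EOS keep their real ids in `labels`, so the loss is computed on the answer alone, and shorter samples are filled with `-100` up to the batch width.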