初始化项目,由ModelHub XC社区提供模型
Model: SPAISS6F1/qwen-1b-pruned-th Source: Original Platform
This commit is contained in:
125
README.md
Normal file
125
README.md
Normal file
@@ -0,0 +1,125 @@
|
||||
---
|
||||
license: apache-2.0
|
||||
language: [th]
|
||||
library_name: transformers
|
||||
pipeline_tag: text-generation
|
||||
tags: [thai, pruning, depth-pruning, layer-dropping]
|
||||
---
|
||||
|
||||
# qwen-1b-pruned-th
|
||||
|
||||
โมเดลภาษาไทยขนาดเล็กที่ได้จากการ **Depth Pruning (Layer Dropping)** ของ `Qwen/Qwen2.5-3B`
|
||||
แล้วทำ **Healing SFT** เพื่อกู้ความสามารถกลับมา
|
||||
|
||||
| | |
|
||||
|---|---|
|
||||
| Base model | `Qwen/Qwen2.5-3B` |
|
||||
| Base size | 3.09B (36 layers) |
|
||||
| **โมเดลนี้** | **1.70B** (เก็บ 18 layers) |
|
||||
| Layers ที่เก็บ | [0-8, 27-35] (ตัด layer กลาง เก็บหัว+ท้าย) |
|
||||
| Healing data | SEA-PILE v2 Thai (~8,000 docs) |
|
||||
| Hardware | NVIDIA A100-40GB (Lanta HPC) |
|
||||
| Requires | `transformers>=4.44`, accelerate |
|
||||
|
||||
---
|
||||
|
||||
## Pipeline การสร้างโมเดล (ทำซ้ำได้)
|
||||
|
||||
### ขั้นที่ 1 — Depth Pruning (Layer Dropping)
|
||||
ตัด decoder layer ตรงกลางทิ้ง (มักทำงานซ้ำซ้อน) เก็บเฉพาะ layer หัว (เข้าใจ input)
|
||||
และ layer ท้าย (สร้าง output) — **embedding / lm_head / norm คงเดิม** จึงไม่พัง dimension
|
||||
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-3B", torch_dtype=torch.bfloat16)
|
||||
|
||||
# หา text-decoder layer list (เลี่ยง vision encoder กรณี multimodal)
|
||||
holder, layers = None, None
|
||||
for _, mod in model.named_modules():
|
||||
L = getattr(mod, "layers", None)
|
||||
if isinstance(L, torch.nn.ModuleList) and len(L) and hasattr(L[0], "self_attn"):
|
||||
holder, layers = mod, L
|
||||
if "language" in _.lower() or "text" in _.lower():
|
||||
break
|
||||
N = len(layers)
|
||||
|
||||
# เก็บ 18 จาก 36 layers: หัว + ท้าย
|
||||
keep = [0,1,2,3,4,5,6,7,8, 27,28,29,30,31,32,33,34,35]
|
||||
holder.layers = torch.nn.ModuleList([layers[i] for i in keep])
|
||||
|
||||
# อัปเดต config (รองรับ nested text_config ของ Gemma3)
|
||||
for c in {model.config, getattr(model.config, "text_config", model.config)}:
|
||||
if getattr(c, "num_hidden_layers", None) is not None:
|
||||
c.num_hidden_layers = len(keep)
|
||||
lt = getattr(c, "layer_types", None)
|
||||
if isinstance(lt, list) and len(lt) == N:
|
||||
c.layer_types = [lt[i] for i in keep]
|
||||
|
||||
# reindex layer_idx ของแต่ละ block (สำคัญต่อ KV cache)
|
||||
for i, lyr in enumerate(holder.layers):
|
||||
if hasattr(lyr, "self_attn") and hasattr(lyr.self_attn, "layer_idx"):
|
||||
lyr.self_attn.layer_idx = i
|
||||
```
|
||||
ผลลัพธ์: 3.09B -> **1.70B** (ยังไม่ถึง 1B เป๊ะ เพราะ embedding+lm_head+vocab ไม่ลดตาม layer)
|
||||
|
||||
> หลัง prune โมเดลจะพ่น gibberish ทันที (เส้นประสาทถูกตัดขาด) -> ต้อง Healing ต่อ
|
||||
|
||||
### ขั้นที่ 2 — Healing SFT
|
||||
เทรนต่อด้วย causal-LM บน Thai corpus เพื่อให้ layer ที่เหลือกลับมาทำงานร่วมกัน
|
||||
|
||||
```python
|
||||
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("SPAISS6F1/spai-ss6-llm-1b-thai-corpus", split="train") # หรือ SEA-PILE v2 'th'
|
||||
tds = ds.map(lambda e: tok(e["text"], truncation=True, max_length=1024),
|
||||
batched=True, remove_columns=ds.column_names)
|
||||
|
||||
model.gradient_checkpointing_enable(); model.config.use_cache = False
|
||||
args = TrainingArguments(output_dir="out", num_train_epochs=2,
|
||||
per_device_train_batch_size=4, gradient_accumulation_steps=4,
|
||||
learning_rate=1e-4, lr_scheduler_type="cosine", warmup_ratio=0.03, bf16=True)
|
||||
Trainer(model=model, args=args, train_dataset=tds,
|
||||
data_collator=DataCollatorForLanguageModeling(tok, mlm=False)).train()
|
||||
```
|
||||
|
||||
**Hyperparameters:**
|
||||
- Learning rate: `1e-4` (สูงกว่าปกติเพื่อสมานแผล) | Epochs: 2
|
||||
- Batch 4 x grad-accum 4 (effective 16) | max_len 1024 | bf16
|
||||
- Optimizer: AdamW + cosine schedule, warmup 3%
|
||||
- Env: conda myenv (transformers 4.49)
|
||||
|
||||
### ขั้นที่ 3 — Save
|
||||
```python
|
||||
model.config.use_cache = True
|
||||
try:
|
||||
model.save_pretrained("out", safe_serialization=True)
|
||||
except RuntimeError: # Gemma3: tied embeddings -> fallback .bin
|
||||
model.save_pretrained("out", safe_serialization=False)
|
||||
```
|
||||
โมเดลนี้ save เป็น: `model.safetensors`
|
||||
|
||||
---
|
||||
|
||||
## วิธีใช้ (Inference)
|
||||
```python
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
m = "SPAISS6F1/qwen-1b-pruned-th"
|
||||
tok = AutoTokenizer.from_pretrained(m)
|
||||
model = AutoModelForCausalLM.from_pretrained(m, torch_dtype=torch.bfloat16, device_map="cuda")
|
||||
ids = tok("ปัญญาประดิษฐ์ คือ", return_tensors="pt").to(model.device)
|
||||
out = model.generate(**ids, max_new_tokens=120, do_sample=True,
|
||||
temperature=0.7, top_p=0.9, repetition_penalty=1.3)
|
||||
print(tok.decode(out[0], skip_special_tokens=True))
|
||||
```
|
||||
|
||||
## ข้อควรรู้ / ข้อจำกัด
|
||||
- เป็น **pruned base ที่ heal ด้วย raw web corpus** -> ไวยากรณ์ไทยลื่นไหลดี
|
||||
แต่ **ข้อเท็จจริงและการคิดเลขยังอ่อน** (ยังไม่ผ่าน instruction tuning)
|
||||
- แนะนำ `repetition_penalty >= 1.2` กันการวนซ้ำ
|
||||
- เหมาะเป็น **base สำหรับ fine-tune ต่อด้วย instruction dataset** มากกว่าใช้ตอบตรง ๆ
|
||||
- การตัด layer 50% เป็นการตัดที่ค่อนข้างหนัก (งานวิจัย เช่น ShortGPT แนะ ~25%);
|
||||
ถ้าต้องการคุณภาพสูงขึ้นควร heal นานขึ้น/ตัดเบาลง
|
||||
Reference in New Issue
Block a user