Files
semantic-turn-taking/training_config.yaml
ModelHub XC a288e309b5 初始化项目,由ModelHub XC社区提供模型
Model: anyreach-ai/semantic-turn-taking
Source: Original Platform
2026-04-29 12:06:38 +08:00

79 lines
1.7 KiB
YAML

# Training configuration for checkpoint-7000 (synth-only-v2-fixed-run2-removedlabelaug)
# This is the config used to train the published model.
# Model identifier, sequence length limit, and the special tokens used by the
# turn-taking predictor.
# NOTE(review): nesting was reconstructed from a copy whose indentation had
# been stripped — confirm against the loader that reads this file.
model:
  name: "Qwen/Qwen2.5-0.5B-Instruct"
  max_length: 1024
  predict_token: "<|predict|>"
  # The four turn-taking action labels.
  action_tokens:
    - "<|continue_listening|>"
    - "<|start_speaking|>"
    - "<|start_listening|>"
    - "<|continue_speaking|>"
  # Currently the same four tokens as action_tokens, in the same order.
  prediction_tokens:
    - "<|continue_listening|>"
    - "<|start_speaking|>"
    - "<|start_listening|>"
    - "<|continue_speaking|>"
# Dataset CSV paths per split: `train` is a list of files, `val` and `test`
# are a single file each.
data:
  train:
    - "data/context_action/synthetic/train.csv"  # ~140K
    - "data/context_action/synthetic/supplementary_start_listening.csv"  # ~14K
  val: "data/context_action/synthetic/val.csv"
  test: "data/context_action/synthetic/test.csv"
# Optimizer, schedule, loss weighting, and early-stopping settings.
# NOTE(review): `loss` and `early_stopping` are placed under `training` here;
# the source copy had lost its indentation, so confirm that the training
# script does not expect them as top-level sections.
training:
  batch_size: 8
  gradient_accumulation_steps: 4  # effective batch size = 32
  # Keep the decimal point and signed exponent: YAML 1.1 loaders such as
  # PyYAML read a bare `5e-5` as a string rather than a float.
  learning_rate: 5.0e-5
  num_epochs: 50
  warmup_steps: 100
  weight_decay: 0.01
  max_grad_norm: 1.0
  bf16: true
  lr_scheduler: "cosine"
  loss:
    ntp_weight: 0.1  # 90% action CE + 10% NTP auxiliary
  early_stopping:
    enabled: true
    patience: 10
    metric: "eval_f1_macro"
# Data augmentation switches. Each sub-section carries its own `enabled`
# flag, so the sub-sections must stay nested; flattened to one level the
# repeated `enabled` / `probability` keys would overwrite one another.
# NOTE(review): nesting was reconstructed from a copy whose indentation had
# been stripped. In particular, `asr_styles` is assumed to sit directly under
# `augmentation` rather than inside `context_truncation` — confirm.
augmentation:
  enabled: true
  context_truncation:
    enabled: true
    probability: 0.2
    min_turns: 1
  asr_styles:
    - "pure"
    - "punctuated"
    - "mixed"
  filler_injection:
    enabled: true
    probability: 0.2
    max_fillers: 3
  disfluency:
    enabled: true
    probability: 0.2
  # Label-changing augmentations DISABLED (broke balanced sampling)
  streaming_crop:
    enabled: false
  backchannel_inject:
    enabled: false
# Output directory plus the step intervals for saving, evaluating, and
# logging; `save_total_limit` is set to 5.
output:
  dir: "models_longer_training"
  save_steps: 1000
  eval_steps: 1000
  logging_steps: 100
  save_total_limit: 5
# Evaluation settings. Kept nested so this `batch_size` does not collide with
# the training `batch_size` key.
evaluation:
  batch_size: 8