---
library_name: transformers
license: apache-2.0
base_model: Qwen/Qwen3-4B-Instruct-2507
tags:
- tool-use
- multi-turn
- agents
- grpo
- reinforcement-learning
- tau2-bench
datasets:
- Jarrodbarnes/tau2-sft-seed-v3
pipeline_tag: text-generation
model-index:
- name: Qwen3-4B-tau2-grpo-v1
  results:
  - task:
      type: multi-turn-tool-use
      name: tau2-bench
    dataset:
      type: tau2-bench
      name: tau2-bench (test split)
    metrics:
    - type: pass@1
      value: 36.0
      name: Pass@1 (Overall)
    - type: pass@4
      value: 59.0
      name: Pass@4 (Overall)
---

# Qwen3-4B-tau2-grpo-v1

A 4B-parameter model fine-tuned for multi-turn tool-use tasks, achieving **59.0% Pass@4** on tau2-bench (test split). This is roughly a **4x improvement** over the base model's 14.3% Pass@4 and suggests that progressive training (SFT -> RFT -> GRPO) is effective for complex, multi-turn agent tasks.

## Model Description

This model was trained with a three-stage pipeline:
1. **SFT (Supervised Fine-Tuning)**: Learning the protocol and tool schemas from successful trajectories
2. **RFT (Rejection Fine-Tuning)**: Concentrating training on high-quality rollouts via rejection sampling
3. **GRPO (Group Relative Policy Optimization)**: Reinforcement learning with turn-level reward shaping

The training process is documented in the [tau2 training cookbook](https://github.com/THUDM/slime/blob/main/examples/tau-bench/training_cookbook.md).

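Stage 3 relies on GRPO, whose core idea is to score each rollout relative to the other rollouts sampled for the same task. The following is a minimal sketch of that group-relative normalization under stated assumptions: the function name, the `eps` constant, and the use of the population standard deviation are illustrative choices, not details taken from the SLIME implementation, and the turn-level reward shaping used for this model is not modeled here.

```python
from statistics import mean, pstdev


def group_relative_advantages(rewards, eps=1e-6):
    """Normalize rewards within one group of rollouts for the same task.

    Hypothetical helper for illustration only: each rollout's advantage is
    its reward minus the group mean, divided by the group standard deviation.
    `eps` (an assumed constant) avoids division by zero when all rewards match.
    """
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


# Four rollouts of one task: two succeeded (reward 1.0), two failed (reward 0.0).
advantages = group_relative_advantages([1.0, 0.0, 0.0, 1.0])
print(advantages)  # successes get positive advantages, failures negative
```

When every rollout in a group receives the same reward, all advantages are zero, so that group contributes no learning signal.
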
## Performance

### tau2-bench Test Split (Pass@4 evaluation)

| Domain | Pass@1 | Pass@4 | Tasks |
|--------|--------|--------|-------|
| **Overall** | **36.0%** | **59.0%** | 100 |
| Airline | 15.0% | 45.0% | 20 |
| Retail | 55.0% | 85.0% | 40 |
| Telecom | 27.5% | 40.0% | 40 |

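The overall row is consistent with a task-weighted average of the three domain rows. The short check below reproduces it from the table; the dictionary layout and the `weighted_overall` helper are illustrative names, and the weighting scheme is an inference from the numbers rather than a documented detail of the evaluation script.

```python
# Per-domain results copied from the table: (Pass@1 %, Pass@4 %, number of tasks).
DOMAIN_RESULTS = {
    "Airline": (15.0, 45.0, 20),
    "Retail": (55.0, 85.0, 40),
    "Telecom": (27.5, 40.0, 40),
}


def weighted_overall(results, column):
    """Average one metric column across domains, weighted by task count."""
    total_tasks = sum(row[2] for row in results.values())
    weighted_sum = sum(row[column] * row[2] for row in results.values())
    return weighted_sum / total_tasks


overall_pass_at_1 = weighted_overall(DOMAIN_RESULTS, column=0)
overall_pass_at_4 = weighted_overall(DOMAIN_RESULTS, column=1)
print(f"Pass@1: {overall_pass_at_1:.1f}%  Pass@4: {overall_pass_at_4:.1f}%")
```
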
### Training Progression

| Stage | Overall Pass@4 |
|-------|----------------|
| Baseline (Qwen3-4B-Instruct) | 14.3% |
| SFT + RFT | 27.0% |
| GRPO (this model) | **59.0%** |

**Eval config**: `temperature=0.8`, `top_p=1.0`, `top_k=20`, `num_samples=4`, `TAU2_USER_MODEL=gpt-4.1-mini`, `TAU2_USER_TEMPERATURE=0.7`, `TAU2_MAX_STEPS=100`.

## Usage

### With SGLang (recommended for evaluation)

```bash
# Start the server (use --tp 1 for single GPU)
python -m sglang.launch_server \
  --model-path Jarrodbarnes/Qwen3-4B-tau2-grpo-v1 \
  --host 0.0.0.0 --port 30000 --tp 2 --mem-fraction-static 0.70
```

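Once the server is running, it can be queried through SGLang's OpenAI-compatible chat endpoint. The sketch below uses only the Python standard library and makes several assumptions: the server is reachable at `localhost:30000` as in the launch command, the sampling values mirror the eval config on this card, `top_k` is accepted as a top-level request field (an SGLang extension to the OpenAI schema), and the helper names and `max_tokens` default are illustrative. The `stop` entry follows the stop-sequence advice in the Function Calling Format section.

```python
import json
import urllib.request

# Assumed endpoint, matching --port 30000 in the launch command.
SERVER_URL = "http://localhost:30000/v1/chat/completions"
MODEL_NAME = "Jarrodbarnes/Qwen3-4B-tau2-grpo-v1"


def build_chat_request(messages, max_tokens=1024):
    """Build a chat-completion payload using the eval sampling settings."""
    return {
        "model": MODEL_NAME,
        "messages": messages,
        "temperature": 0.8,
        "top_p": 1.0,
        "top_k": 20,  # assumption: SGLang-specific field, not in the OpenAI spec
        "max_tokens": max_tokens,  # illustrative default, not from the eval config
        "stop": ["</tool_call>"],
    }


def send_chat_request(payload, url=SERVER_URL, timeout=120):
    """POST the payload and return the assistant message text."""
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        body = json.load(response)
    return body["choices"][0]["message"]["content"]


payload = build_chat_request([{"role": "user", "content": "I need to change my flight."}])
# reply = send_chat_request(payload)  # requires the SGLang server to be running
```
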
### With Transformers

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Jarrodbarnes/Qwen3-4B-tau2-grpo-v1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
```

### Function Calling Format

This model uses the Qwen3 native function-calling format:

```
<tool_call>{"name": "tool_name", "arguments": {"arg": "value"}}</tool_call>
```

Include `</tool_call>` in the stop sequences so that generation halts after each tool call and the output can be parsed reliably.

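A minimal sketch of parsing this format from raw model output follows. The `parse_tool_calls` helper is a hypothetical name, not part of any library. It tolerates a missing closing tag because inference servers commonly strip the stop string from the returned text, and it silently skips malformed JSON, which is a design choice a stricter caller may want to replace with an error.

```python
import json

OPEN_TAG = "<tool_call>"
CLOSE_TAG = "</tool_call>"


def parse_tool_calls(text):
    """Extract tool calls from model output as a list of dicts.

    Hypothetical helper: returns one dict per well-formed call, each with
    a "name" key and an "arguments" key (defaulting to an empty dict).
    """
    calls = []
    # Everything after each opening tag is a candidate JSON payload.
    for chunk in text.split(OPEN_TAG)[1:]:
        # Drop the closing tag and anything after it, if the tag is present.
        payload = chunk.split(CLOSE_TAG, 1)[0].strip()
        try:
            call = json.loads(payload)
        except json.JSONDecodeError:
            continue  # skip malformed payloads instead of raising
        if isinstance(call, dict) and "name" in call:
            call.setdefault("arguments", {})
            calls.append(call)
    return calls


output = 'Let me check.\n<tool_call>{"name": "tool_name", "arguments": {"arg": "value"}}</tool_call>'
print(parse_tool_calls(output))
```
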
## Training Details

- **Base model**: [Qwen/Qwen3-4B-Instruct-2507](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507)
- **Training framework**: [SLIME](https://github.com/THUDM/slime) (Megatron-LM + SGLang)
- **SFT data**: [tau2-sft-seed-v3](https://huggingface.co/datasets/Jarrodbarnes/tau2-sft-seed-v3)
- **GRPO steps**: 21 optimizer steps
- **Reward shaping**: Turn-level partial scores from tau2-bench reward_info
- **User simulator (training)**: Local Qwen3-4B-Instruct on port 30001
- **User simulator (eval)**: GPT-4.1-mini via OpenAI API

### W&B Training Logs

- SFT run: [b7d80rfe](https://wandb.ai/jbarnes850-near-protocol/tau2-cookbook/runs/b7d80rfe)
- GRPO run: [pkeu9kck](https://wandb.ai/jbarnes850-near-protocol/tau2-cookbook/runs/pkeu9kck)

## Resources

- [Training Cookbook](https://github.com/THUDM/slime/blob/main/examples/tau-bench/training_cookbook.md) - Full methodology and reproduction steps
- [SFT Checkpoint](https://huggingface.co/Jarrodbarnes/Qwen3-4B-tau2-sft1) - Intermediate SFT+RFT checkpoint
- [Training Dataset](https://huggingface.co/datasets/Jarrodbarnes/tau2-sft-seed-v3) - Filtered RFT trajectories
- [tau2-bench](https://github.com/sierra-research/tau2-bench) - Benchmark repository

## Limitations

- **Telecom domain**: Dual-control tasks (where the agent must instruct the user rather than execute actions itself) remain challenging (40.0% Pass@4)
- **User simulator sensitivity**: Results vary with the choice of user simulator; GPT-4.1-mini is recommended for reproducibility
- **Pass@k vs Pass^k**: This card reports pass@k (at least one success in k attempts), not the pass^k metric (success on all k attempts) used on the official tau2-bench leaderboard, so these numbers are not directly comparable with leaderboard scores

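To make the difference between the two metrics concrete, the sketch below computes both for a single task from `n` trials with `c` successes, using the standard combinatorial estimators. The function names are illustrative, and this is not necessarily how the numbers on this card were computed.

```python
from math import comb


def pass_at_k(n, c, k):
    """Estimate the chance that at least one of k sampled trials succeeds."""
    if not 0 <= c <= n or not 1 <= k <= n:
        raise ValueError("require 0 <= c <= n and 1 <= k <= n")
    # comb(n - c, k) counts the k-subsets that contain only failures.
    return 1.0 - comb(n - c, k) / comb(n, k)


def pass_hat_k(n, c, k):
    """Estimate the chance that all k sampled trials succeed (pass^k)."""
    if not 0 <= c <= n or not 1 <= k <= n:
        raise ValueError("require 0 <= c <= n and 1 <= k <= n")
    # comb(c, k) counts the k-subsets that contain only successes.
    return comb(c, k) / comb(n, k)


# A task solved in 1 of 4 trials counts fully toward pass@4 but not at all toward pass^4.
print(pass_at_k(n=4, c=1, k=4), pass_hat_k(n=4, c=1, k=4))
```

For `k=1` the two metrics coincide; they diverge as `k` grows, with pass@k rewarding any success and pass^k rewarding consistency.
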
## Citation

```bibtex
@misc{qwen3-tau2-grpo,
  title={Qwen3-4B-tau2-grpo-v1: Multi-Turn Tool-Use Agent via Progressive RL Training},
  author={Jarrod Barnes},
  year={2025},
  url={https://huggingface.co/Jarrodbarnes/Qwen3-4B-tau2-grpo-v1}
}
```

## Acknowledgments

- [Qwen Team](https://github.com/QwenLM/Qwen3) for the base model
- [Sierra Research](https://github.com/sierra-research/tau2-bench) for tau2-bench
- [THUDM](https://github.com/THUDM/slime) for the SLIME training framework