---
license: mit
language:
- en
base_model: Qwen/Qwen2.5-7B-Instruct
tags:
- reinforcement-learning
- grpo
- tool-use
- debugging
- dsl
datasets:
- custom
pipeline_tag: text-generation
model-index:
- name: dsl-debug-7b-sft-rl
  results:
  - task:
      type: text-generation
      name: DSL Debugging (Standard)
    metrics:
    - type: accuracy
      value: 86.1
      name: Standard Bugs (481)
  - task:
      type: text-generation
      name: DSL Debugging (Nonlocal)
    metrics:
    - type: accuracy
      value: 70.5
      name: Nonlocal Bugs (200)
  - task:
      type: text-generation
      name: DSL Debugging (Intent-Mismatch)
    metrics:
    - type: accuracy
      value: 28.2
      name: Intent-Mismatch Bugs (177)
---
# DSL Debug 7B: SFT then RL (Best Checkpoint)

Qwen2.5-7B-Instruct fine-tuned with SFT then GRPO reinforcement learning to debug programs in a custom dataflow DSL. This is the best-performing checkpoint (step 35 of 40 RL steps).

**Blog post:** [Multi-Turn RL for Code Debugging](https://andrewlngdn.github.io/dsl_debugger/)

**Code + environment:** [github.com/AndrewLngdn/dsl-debug](https://github.com/AndrewLngdn/dsl-debug)
## Results

Accuracy on the held-out test set, one-shot evaluation:

| Method | Standard Bugs (481) | Nonlocal Bugs (200) | Intent-Mismatch Bugs (177) |
|--------|:-:|:-:|:-:|
| Prompt Engineering (base) | 50.5% | 12.0% | 0.6% |
| SFT (step 100) | 56.3% | 40.0% | 7.9% |
| RL-only (step 30) | 78.8% | 54.0% | 14.7% |
| **SFT then RL (this model)** | **86.1%** | **70.5%** | **28.2%** |
### Alignment Tax

Scores on general benchmarks, base model vs. this checkpoint:

| Benchmark | Base | This Model |
|-----------|------|-----------|
| MMLU (5-shot) | 74.6% | 74.5% |
| GSM8K (8-shot) | 84.9% | 84.1% |
| HumanEval (0-shot) | 65.9% | 62.2% |
## Training

Two-stage training on 2x A100-80GB GPUs using [verl](https://github.com/volcengine/verl) 0.7:

1. **SFT**: 1,593 expert trajectories from GPT-5-mini, full-parameter updates, LR=5e-6, 2 epochs (checkpoint at step 100)
2. **GRPO**: 6,420 RL problems, LR=1e-5 with cosine decay, batches of 512 prompts x 8 rollouts each, 40 steps, no KL penalty (advantage computation sketched below)
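
GRPO needs no learned value function: each rollout's advantage is its reward normalized against the other rollouts sampled for the same prompt (groups of 8 here). A minimal sketch of that step, with an illustrative function name rather than verl's internal API:

```python
import numpy as np

def grpo_advantages(rewards, eps=1e-6):
    """Group-relative advantages for one prompt's rollouts (8 per prompt here)."""
    r = np.asarray(rewards, dtype=np.float32)
    # Normalize each rollout's reward against its own group's statistics
    return (r - r.mean()) / (r.std() + eps)

# Binary episode rewards for one prompt's 8 rollouts:
# successful fixes get positive advantage, failures negative
print(grpo_advantages([1, 0, 0, 1, 0, 0, 0, 1]))
```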
## The Task

The model debugs programs in a custom pipe-based dataflow DSL. Each episode provides buggy code and the expected output. The model has 8 turns and 4 tools (the interaction loop is sketched after this list):

- `run(code)`: Execute DSL code
- `inspect(node_name)`: View a node's intermediate table output
- `read_docs(operation)`: Read the DSL documentation for an operation
- `submit(code)`: Submit a fix (ends the episode, binary reward)
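
To make the episode protocol concrete, here is a minimal sketch of the interaction loop. The `env`/`policy` interface is a hypothetical stand-in for the actual harness in the linked repo:

```python
# Illustrative episode loop -- `env` and `policy` are hypothetical stand-ins
# for the dsl-debug harness, not its real API.
def run_episode(env, policy, max_turns=8):
    obs = env.reset()  # buggy DSL code + expected output
    for _ in range(max_turns):
        # Policy picks one tool call: run(...), inspect(...), read_docs(...), or submit(...)
        action = policy(obs)
        obs, reward, done = env.step(action)
        if done:  # submit(code) ends the episode
            return reward  # 1.0 if the fix produces the expected output, else 0.0
    return 0.0  # no submission within the turn budget
```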
## Usage

```python
# With sglang
import sglang as sgl
from sglang import RuntimeEndpoint

# Connect to a running sglang server and make it the default backend
runtime = RuntimeEndpoint("http://localhost:30000")
sgl.set_default_backend(runtime)

# Or download the weights and serve them yourself
from huggingface_hub import snapshot_download

snapshot_download("andrewlngdn/dsl-debug-7b-sft-rl", local_dir="./model")
```

```bash
# Using the dsl-debug CLI
pip install dsl-debug
dsl-debug sglang                 # downloads and serves this model
dsl-debug eval --split standard  # evaluate on the test set
```
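
Once a server is running, the model can also be queried through sglang's OpenAI-compatible endpoint. A minimal sketch, assuming the default port above; the single-turn prompt is a placeholder, since the real harness drives the full multi-turn tool loop:

```python
# Query the served checkpoint via the OpenAI-compatible API
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="andrewlngdn/dsl-debug-7b-sft-rl",  # model name as served
    messages=[{"role": "user", "content": "Debug this DSL program: ..."}],
    temperature=0.0,
)
print(resp.choices[0].message.content)
```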
## Related Models

| Model | Repo |
|-------|------|
| SFT step 100 | [andrewlngdn/dsl-debug-7b-sft-step100](https://huggingface.co/andrewlngdn/dsl-debug-7b-sft-step100) |
| RL-only step 30 | [andrewlngdn/dsl-debug-7b-rl-only-step30](https://huggingface.co/andrewlngdn/dsl-debug-7b-rl-only-step30) |