初始化项目,由ModelHub XC社区提供模型
Model: andrewlngdn/dsl-debug-7b-sft-rl Source: Original Platform
This commit is contained in:
109
README.md
Normal file
109
README.md
Normal file
@@ -0,0 +1,109 @@
|
||||
---
|
||||
license: mit
|
||||
language:
|
||||
- en
|
||||
base_model: Qwen/Qwen2.5-7B-Instruct
|
||||
tags:
|
||||
- reinforcement-learning
|
||||
- grpo
|
||||
- tool-use
|
||||
- debugging
|
||||
- dsl
|
||||
datasets:
|
||||
- custom
|
||||
pipeline_tag: text-generation
|
||||
model-index:
|
||||
- name: dsl-debug-7b-sft-rl
|
||||
results:
|
||||
- task:
|
||||
type: text-generation
|
||||
name: DSL Debugging (Standard)
|
||||
metrics:
|
||||
- type: accuracy
|
||||
value: 86.1
|
||||
name: Standard Bugs (481)
|
||||
- task:
|
||||
type: text-generation
|
||||
name: DSL Debugging (Nonlocal)
|
||||
metrics:
|
||||
- type: accuracy
|
||||
value: 70.5
|
||||
name: Nonlocal Bugs (200)
|
||||
- task:
|
||||
type: text-generation
|
||||
name: DSL Debugging (Intent-Mismatch)
|
||||
metrics:
|
||||
- type: accuracy
|
||||
value: 28.2
|
||||
name: Intent-Mismatch Bugs (177)
|
||||
---
|
||||
|
||||
# DSL Debug 7B: SFT then RL (Best Checkpoint)
|
||||
|
||||
Qwen2.5-7B-Instruct fine-tuned with SFT then GRPO reinforcement learning to debug programs in a custom dataflow DSL. This is the best-performing checkpoint (step 35 of 40 RL steps).
|
||||
|
||||
**Blog post:** [Multi-Turn RL for Code Debugging](https://andrewlngdn.github.io/dsl_debugger/)
|
||||
**Code + environment:** [github.com/AndrewLngdn/dsl-debug](https://github.com/AndrewLngdn/dsl-debug)
|
||||
|
||||
## Results
|
||||
|
||||
Held-out test set, one-shot evaluation:
|
||||
|
||||
| Method | Standard (481) | Nonlocal (200) | Intent-Mismatch (177) |
|
||||
|--------|:-:|:-:|:-:|
|
||||
| Prompt Engineering (base) | 50.5% | 12.0% | 0.6% |
|
||||
| SFT (step 100) | 56.3% | 40.0% | 7.9% |
|
||||
| RL-only (step 30) | 78.8% | 54.0% | 14.7% |
|
||||
| **SFT then RL (this model)** | **86.1%** | **70.5%** | **28.2%** |
|
||||
|
||||
### Alignment Tax
|
||||
|
||||
| Benchmark | Base | This Model |
|
||||
|-----------|------|-----------|
|
||||
| MMLU (5-shot) | 74.6% | 74.5% |
|
||||
| GSM8K (8-shot) | 84.9% | 84.1% |
|
||||
| HumanEval (0-shot) | 65.9% | 62.2% |
|
||||
|
||||
## Training
|
||||
|
||||
Two-stage training on 2x A100-80GB using [verl](https://github.com/volcengine/verl) 0.7:
|
||||
|
||||
1. **SFT**: 1,593 expert trajectories from GPT-5-mini, full parameter updates, LR=5e-6, 2 epochs (step 100)
|
||||
2. **GRPO**: 6,420 RL problems, LR=1e-5 cosine, batch 512 prompts x 8 rollouts, 40 steps, no KL penalty
|
||||
|
||||
## The Task
|
||||
|
||||
The model debugs programs in a custom pipe-based dataflow DSL. Each episode provides buggy code and expected output. The model has 8 turns and 4 tools:
|
||||
|
||||
- `run(code)`: Execute DSL code
|
||||
- `inspect(node_name)`: View intermediate table output
|
||||
- `read_docs(operation)`: Read DSL documentation
|
||||
- `submit(code)`: Submit fix (ends episode, binary reward)
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
# With sglang
|
||||
from sglang import RuntimeEndpoint
|
||||
import sglang as sgl
|
||||
|
||||
runtime = RuntimeEndpoint("http://localhost:30000")
|
||||
|
||||
# Or download and serve
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download("andrewlngdn/dsl-debug-7b-sft-rl", local_dir="./model")
|
||||
```
|
||||
|
||||
```bash
|
||||
# Using the dsl-debug CLI
|
||||
pip install dsl-debug
|
||||
dsl-debug sglang # downloads and serves this model
|
||||
dsl-debug eval --split standard # evaluate on test set
|
||||
```
|
||||
|
||||
## Related Models
|
||||
|
||||
| Model | Repo |
|
||||
|-------|------|
|
||||
| SFT step 100 | [andrewlngdn/dsl-debug-7b-sft-step100](https://huggingface.co/andrewlngdn/dsl-debug-7b-sft-step100) |
|
||||
| RL-only step 30 | [andrewlngdn/dsl-debug-7b-rl-only-step30](https://huggingface.co/andrewlngdn/dsl-debug-7b-rl-only-step30) |
|
||||
Reference in New Issue
Block a user