Initialize project; model provided by the ModelHub XC community
Model: GGOSinon/babyai-world-model-7B-sft Source: Original Platform
---
license: apache-2.0
base_model: Qwen/Qwen2.5-7B-Instruct
tags:
- world-model
- babyai
- reinforcement-learning
- model-based-rl
- sft
- lora
datasets:
- GGOSinon/babyai-world-model-sft
language:
- en
pipeline_tag: text-generation
---

# BabyAI World Model (Qwen2.5-7B SFT)

A world model for the BabyAI grid-world environment, fine-tuned from Qwen2.5-7B-Instruct using LoRA. This model predicts the next observation and available actions given the current state and the agent's action.

## Model Details

- **Base model**: Qwen/Qwen2.5-7B-Instruct
- **Fine-tuning**: LoRA (40.4M trainable parameters, 0.53% of 7.66B total), merged into the base model after training
- **Training data**: [GGOSinon/babyai-world-model-sft](https://huggingface.co/datasets/GGOSinon/babyai-world-model-sft) (58K transitions, 1 epoch)
- **Training time**: ~5.5 hours on 1x A100 40GB
- **Final loss**: 0.023

## Performance (Done-Detection, 102 test cases)

| Model | Accuracy | Precision | Recall | FPR |
|---|---|---|---|---|
| Qwen2.5-7B zero-shot | 76.5% | 91.7% | 32.4% | 1.5% |
| **Qwen2.5-7B SFT (this model)** | **97.1%** | **100%** | **91.2%** | **0.0%** |
| Gemini 2.5 Flash zero-shot | 97.1% | 100% | 91.2% | 0.0% |
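The table's numbers follow the standard confusion-matrix definitions. As a sanity check, a minimal sketch, assuming the 102 cases split into 34 positives and 68 negatives (an inference from the percentages, not stated in this card):

```python
def done_detection_metrics(tp, fp, tn, fn):
    """Confusion-matrix metrics as reported in the table above."""
    total = tp + fp + tn + fn
    return {
        "accuracy": (tp + tn) / total,
        "precision": tp / (tp + fp) if tp + fp else 1.0,
        "recall": tp / (tp + fn),
        "fpr": fp / (fp + tn),
    }

# Assumed counts for the SFT row: tp=31, fp=0, tn=68, fn=3
m = done_detection_metrics(tp=31, fp=0, tn=68, fn=3)
# accuracy ≈ 0.971, precision = 1.0, recall ≈ 0.912, fpr = 0.0
```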
## Usage
|
||||
|
||||
```python
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained("GGOSinon/babyai-world-model-7B-sft", torch_dtype="bfloat16").to("cuda")
|
||||
tokenizer = AutoTokenizer.from_pretrained("GGOSinon/babyai-world-model-7B-sft")
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a simulator for a grid-world environment called BabyAI..."},
|
||||
{"role": "user", "content": "Goal: pick up the red box\n\nObservation:\n...\nAvailable actions: [...]\nAgent's action: pickup red box 1"}
|
||||
]
|
||||
|
||||
inputs = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt", add_generation_prompt=True).to("cuda")
|
||||
output = model.generate(inputs, max_new_tokens=300, do_sample=False)
|
||||
print(tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True))
|
||||
```
|
||||
|
||||

## Output Format

The model outputs in this format:

```
<observation>next observation text</observation>
<available_actions>["action1", "action2", ...]</available_actions>
```

Task completion is indicated by "The task is completed." appended to the observation text.
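Driving a rollout requires parsing these tags back into structured state. A minimal sketch (the `parse_world_model_output` helper is illustrative, not part of the released code):

```python
import json
import re

def parse_world_model_output(text):
    """Split the model's output into observation text, action list, and a done flag.

    Tag names match the output format documented above; this helper itself
    is an illustrative assumption, not shipped with the model.
    """
    obs_match = re.search(r"<observation>(.*?)</observation>", text, re.DOTALL)
    act_match = re.search(r"<available_actions>(.*?)</available_actions>", text, re.DOTALL)
    observation = obs_match.group(1).strip() if obs_match else None
    actions = json.loads(act_match.group(1)) if act_match else []
    # Per the card, done is signalled by this sentence inside the observation
    done = observation is not None and "The task is completed." in observation
    return observation, actions, done

obs, actions, done = parse_world_model_output(
    '<observation>You see a red box 1. The task is completed.</observation>'
    '<available_actions>["turn left", "turn right"]</available_actions>'
)
```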