---
license: apache-2.0
base_model: Qwen/Qwen2.5-7B-Instruct
tags:
- world-model
- babyai
- reinforcement-learning
- model-based-rl
- sft
- lora
datasets:
- GGOSinon/babyai-world-model-sft
language:
- en
pipeline_tag: text-generation
---

# BabyAI World Model (Qwen2.5-7B SFT)

A world model for the BabyAI grid-world environment, fine-tuned from Qwen2.5-7B-Instruct using LoRA. Given the current state and the agent's action, the model predicts the next observation and the actions available in that next state.

## Model Details

- **Base model**: Qwen2.5-7B-Instruct
- **Fine-tuning**: LoRA (40.4M trainable params, 0.53% of 7.66B total), merged after training
- **Training data**: [GGOSinon/babyai-world-model-sft](https://huggingface.co/datasets/GGOSinon/babyai-world-model-sft) (58K transitions, 1 epoch)
- **Training time**: ~5.5 hours on 1x A100 40GB
- **Final loss**: 0.023

## Performance (Done-Detection, 102 test cases)

| Model | Accuracy | Precision | Recall | FPR |
|---|---|---|---|---|
| Qwen2.5-7B zero-shot | 76.5% | 91.7% | 32.4% | 1.5% |
| **Qwen2.5-7B SFT (this model)** | **97.1%** | **100%** | **91.2%** | **0.0%** |
| Gemini 2.5 Flash zero-shot | 97.1% | 100% | 91.2% | 0.0% |

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the merged checkpoint in bfloat16 and move it to the GPU.
model = AutoModelForCausalLM.from_pretrained(
    "GGOSinon/babyai-world-model-7B-sft", torch_dtype="bfloat16"
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("GGOSinon/babyai-world-model-7B-sft")

# The system prompt describes the simulator role; the user turn carries the
# goal, the current observation, the available actions, and the chosen action.
messages = [
    {"role": "system", "content": "You are a simulator for a grid-world environment called BabyAI..."},
    {"role": "user", "content": "Goal: pick up the red box\n\nObservation:\n...\nAvailable actions: [...]\nAgent's action: pickup red box 1"}
]

inputs = tokenizer.apply_chat_template(
    messages, tokenize=True, return_tensors="pt", add_generation_prompt=True
).to("cuda")

# Greedy decoding; print only the newly generated tokens, not the prompt.
output = model.generate(inputs, max_new_tokens=300, do_sample=False)
print(tokenizer.decode(output[0][inputs.shape[1]:], skip_special_tokens=True))
```

## Output Format

The model outputs in this format:

```
next observation text
["action1", "action2", ...]
```

Task completion is indicated by "The task is completed." appended to the observation text.
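
For downstream use, the generated text usually has to be split back into its parts. The following is a minimal sketch of one way to do that, not the author's own parsing code. It assumes the action list is the last JSON-style array in the output and that action strings contain no square brackets. The function name `parse_world_model_output` and the sample strings in the usage note are hypothetical.

```python
import json

# Completion marker quoted in the model card.
DONE_MARKER = "The task is completed."


def parse_world_model_output(text):
    """Split a generation into (observation, actions, done).

    Assumption: the available-actions list is the last JSON array in the
    text. If no valid array is found, the whole text is treated as the
    observation and the action list is empty.
    """
    text = text.strip()
    observation, actions = text, []

    start = text.rfind("[")
    end = text.rfind("]")
    if start != -1 and end > start:
        try:
            parsed = json.loads(text[start:end + 1])
        except json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, list):
            actions = [str(a) for a in parsed]
            observation = text[:start].strip()

    # The marker is appended to the observation text when the task is done.
    done = DONE_MARKER in observation
    return observation, actions, done
```

Called on a hypothetical generation such as `'You carry a red box 1. The task is completed.\n[]'`, the function returns the observation text, an empty action list, and `done=True`. Checking the marker only inside the observation part avoids a false positive if an action string happened to contain similar wording.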