---
library_name: transformers
license: cc-by-nc-4.0
base_model: meta-llama/Llama-3.1-8B-Instruct
tags:
- rag
- filtering
---

## Model Description

This model is a fine-tuned version of [meta-llama/Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct), trained for 🚀**evidence relevance classification or evidence filtering**🚀 in medical RAG pipelines.
Given a clinical query and a candidate passage, the model outputs *“Yes”* if the passage contains supporting evidence and *“No”* otherwise.

This lightweight classifier is designed to help researchers:
- Improve retrieval quality in medical RAG systems.
- Filter irrelevant passages before generation.
- Build more reliable, interpretable RAG pipelines for medical QA.

For additional context, methodology, and full experimental details, please refer to our paper below.

📄 **Paper**: [Rethinking Retrieval-Augmented Generation for Medicine: A Large-Scale, Systematic Expert Evaluation and Practical Insights](https://arxiv.org/abs/2511.06738)

## Quick Start

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "Yale-BIDS-Chen/Llama-3.1-8B-Evidence-Filtering"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Instruction used during training
INSTRUCTION = (
    "Given a query and a text passage, determine whether the passage contains supporting evidence for the query. "
    "Supporting evidence means that the passage provides clear, relevant, and factual information that directly backs or justifies the answer to the query.\n\n"
    "Respond with one of the following labels:\n\"Yes\" if the passage contains supporting evidence for the query.\n"
    "\"No\" if the passage does not contain supporting evidence.\n"
    "You should respond with only the label (Yes or No) without any additional explanation."
)

# Example query + retrieved passage
query = "What is the first-line treatment for acute angle-closure glaucoma?"
doc = "Acute angle-closure glaucoma requires immediate treatment with topical beta-blockers, alpha agonists, and systemic carbonic anhydrase inhibitors."

# Build chat-style prompt
content = tokenizer.apply_chat_template(
    [
        {"role": "system", "content": INSTRUCTION},
        {"role": "user", "content": f"Question: {query}\nPassage: {doc}"},
    ],
    add_generation_prompt=True,
    tokenize=False,
)

# Tokenize
input_ids = tokenizer(content, return_tensors="pt").input_ids.to(model.device)

# Define stopping tokens (Llama-3 style)
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

# Generate evidence-filtering judgment.
# Greedy decoding: no temperature is passed, since it is ignored when do_sample=False.
outputs = model.generate(
    input_ids=input_ids,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=False,
)

# Decode only the newly generated tokens (the model response)
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
```

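
The Yes/No judgment above can be turned into a passage filter for a RAG pipeline. The following is a minimal sketch, not part of the released code: `parse_label`, `filter_passages`, and `keyword_judge` are hypothetical names, and `keyword_judge` is a stand-in so the snippet runs without loading the model. In practice, the `judge` callable would wrap the prompt construction and `model.generate` call from the Quick Start snippet and return the decoded response.

```python
from typing import Callable, Iterable, List, Optional


def parse_label(raw: str) -> Optional[bool]:
    """Map a raw model response to True ("Yes"), False ("No"), or None if unclear."""
    # Strip whitespace, surrounding quotes, and trailing punctuation before comparing.
    cleaned = raw.strip().strip("\"'“”.").lower()
    if cleaned == "yes":
        return True
    if cleaned == "no":
        return False
    return None


def filter_passages(
    query: str,
    passages: Iterable[str],
    judge: Callable[[str, str], str],
) -> List[str]:
    """Keep only the passages that `judge` labels "Yes" for the given query."""
    kept = []
    for passage in passages:
        if parse_label(judge(query, passage)) is True:
            kept.append(passage)
    return kept


# Stand-in judge for illustration only (an assumption, not the model's behavior).
def keyword_judge(query: str, passage: str) -> str:
    return "Yes" if "glaucoma" in passage.lower() else "No"


passages = [
    "Acute angle-closure glaucoma requires immediate pressure-lowering treatment.",
    "Seasonal influenza vaccination is recommended annually.",
]
kept = filter_passages(
    "How is acute angle-closure glaucoma treated?", passages, keyword_judge
)
print(kept)
```

In this sketch, responses that parse to neither label are dropped along with the "No" passages; a pipeline that prefers recall over precision could keep them instead.
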
## Training Setup

- **Dataset:** 3,200 query–passage pairs with expert-provided Yes/No labels (dataset to be released in a future update).
- **Task:** Given a query and a candidate passage, the model generates *"Yes"* if the passage contains supporting evidence and *"No"* otherwise.
- **Objective:** Causal language modeling (cross-entropy next-token loss).
- **Prompt:** See the *Quick Start* section for an example usage prompt.
- **Hyperparameter Tuning:** Five-fold cross-validation.
- **Final Hyperparameters:**
  - Learning rate: 2e-6
  - Batch size: 8
  - Epochs: 3
- **Training Framework:** [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory).

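
To make the five-fold protocol concrete, the sketch below splits 3,200 example indices into five folds, so each run trains on 2,560 pairs and holds out 640. This is an illustration only: the model card does not state how the folds were constructed (for example, whether they were stratified by label or grouped by query), so the plain random shuffle, the seed, and the helper name `k_fold_indices` are all assumptions.

```python
import random
from typing import Iterator, List, Tuple


def k_fold_indices(
    n_items: int, k: int = 5, seed: int = 0
) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, held_out_indices) for each of k folds."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # seed is an arbitrary illustrative choice
    fold_size, remainder = divmod(n_items, k)
    start = 0
    for fold in range(k):
        # Spread any remainder over the first folds so sizes differ by at most one.
        end = start + fold_size + (1 if fold < remainder else 0)
        held_out = indices[start:end]
        train = indices[:start] + indices[end:]
        yield train, held_out
        start = end


folds = list(k_fold_indices(3200, k=5))
for train, held_out in folds:
    print(len(train), len(held_out))  # 2560 640 for every fold
```
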
## Performance

Evaluation was conducted on 3,200 expert-annotated query–passage pairs using five-fold cross-validation. The best value in each column is shown in bold.

| Model                               | Precision | Recall    | F1        |
|-------------------------------------|-----------|-----------|-----------|
| **Llama-3.1-8B (zero-shot)**        | 0.483     | 0.566     | 0.521     |
| **GPT-4o (zero-shot)**              | **0.697** | 0.324     | 0.442     |
| **Llama-3.1-8B (fine-tuned, ours)** | 0.592     | **0.657** | **0.623** |

🔥 Fine-tuning yields the highest recall and F1 among the compared models. Zero-shot GPT-4o has higher precision (0.697) but much lower recall (0.324).

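
As a quick sanity check on the table, F1 is the harmonic mean of precision and recall, F1 = 2PR / (P + R). The short sketch below recomputes it from the reported precision and recall; the helper name `f1_score` is ours. The card does not say whether the reported figures are pooled or averaged across folds, so this only confirms that the three columns are mutually consistent to three decimals.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


# (precision, recall, reported F1) copied from the Performance table
reported = {
    "Llama-3.1-8B (zero-shot)": (0.483, 0.566, 0.521),
    "GPT-4o (zero-shot)": (0.697, 0.324, 0.442),
    "Llama-3.1-8B (fine-tuned)": (0.592, 0.657, 0.623),
}

for name, (precision, recall, f1_reported) in reported.items():
    f1 = f1_score(precision, recall)
    print(f"{name}: recomputed F1 = {f1:.3f} (reported {f1_reported:.3f})")
```
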
## Intended Use

This model is intended for research purposes only.

## Reference

Please see the information below to cite our paper.
```bibtex
@article{kim2025rethinking,
  title={Rethinking Retrieval-Augmented Generation for Medicine: A Large-Scale, Systematic Expert Evaluation and Practical Insights},
  author={Kim, Hyunjae and Sohn, Jiwoong and Gilson, Aidan and Cochran-Caggiano, Nicholas and Applebaum, Serina and Jin, Heeju and Park, Seihee and Park, Yujin and Park, Jiyeong and Choi, Seoyoung and others},
  journal={arXiv preprint arXiv:2511.06738},
  year={2025}
}
```

## Contact

Feel free to email `hyunjae.kim@yale.edu` if you have any questions.