# llama-3.1-8b-bib-grounded-s…/handler.py
# Inference endpoint handler (Python; listed by the source viewer as 315 lines, 16 KiB).

import os
import re

import torch
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer

# Default instruction used by EndpointHandler._build_messages when a request
# does not provide its own "system" text. The sentences are listed one per
# entry and joined with single spaces; the resulting string must stay
# byte-identical, because the served model may be sensitive to prompt wording.
# NOTE(review): the abstention sentence runs the exact abstention phrase
# straight into its condition ("...requested information after checking..."),
# unlike the quoted form used in the per-question response rule — consider
# aligning the two (would change the prompt, so verify against evals first).
DEFAULT_INSTRUCTION = " ".join(
    (
        "You are an expert research assistant for the Born in Bradford study.",
        "Use only the provided context.",
        "If the context contains the requested value, relationship, or finding, answer it directly; do not say it is absent.",
        "Only answer exactly: The provided context does not contain the requested information after checking paper excerpts, labels, headings, rows, and adjacent context for an exact match.",
        "For table-like or multi-finding context, choose the answer whose row label, column label, outcome, subgroup, timepoint, model, and measurement all match the question; ignore nearby rows where any anchor differs.",
        "Follow the response rule after the question, because it identifies the exact extraction style needed.",
        "Preserve all numbers, ranges, units, p-values, confidence intervals, dates, "
        "distances, comparison groups, and categories exactly as written in the context.",
        "Do not add page numbers, figure numbers, table positions, p-values, methods, "
        "ranges, or qualifiers unless they appear in the provided context.",
        "Answer only the requested value, relationship, or finding; do not include "
        "adjacent table values, extra categories, variable metadata, or explanatory detail unless asked.",
        "Do not add partial related information after abstaining.",
    )
)
class EndpointHandler:
    """Request handler that serves a causal LM, optionally loaded as a PEFT adapter.

    Each call builds chat messages from the request, appends a question-specific
    "Response rule:" block to the last user turn, renders the messages with the
    tokenizer's chat template and returns the decoded continuation.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model from ``path`` (``/repository`` if empty).

        When ``adapter_config.json`` exists in the directory the weights are
        loaded with ``AutoPeftModelForCausalLM``; otherwise with
        ``AutoModelForCausalLM``. fp16 + ``device_map="auto"`` is used when CUDA
        is available, fp32 on CPU otherwise.
        """
        model_dir = path or "/repository"
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_dir,
            trust_remote_code=True,
        )
        if self.tokenizer.pad_token_id is None:
            # generate() needs a pad id; reuse EOS when the tokenizer has none.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.truncation_side = "left"
        use_cuda = torch.cuda.is_available()
        load_kwargs = dict(
            trust_remote_code=True,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            low_cpu_mem_usage=True,
            device_map="auto" if use_cuda else None,
        )
        if os.path.exists(os.path.join(model_dir, "adapter_config.json")):
            self.model = AutoPeftModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
        self.model.eval()

    def _coerce_messages(self, raw_messages):
        """Keep only well-formed ``{"role", "content"}`` dicts.

        Non-dict items and roles other than system/user/assistant (compared
        case-insensitively after stripping) are dropped. ``None`` content
        becomes ``""``, and other content is converted with ``str``.
        """
        messages = []
        for message in raw_messages:
            if not isinstance(message, dict):
                continue
            role = str(message.get("role", "")).strip().lower()
            if role not in {"system", "user", "assistant"}:
                continue
            content = message.get("content", "")
            if content is None:
                content = ""
            messages.append({"role": role, "content": str(content)})
        return messages

    def _merge_system_into_first_user(self, messages):
        """Fold all system messages into the first user turn.

        System contents are joined with blank lines and prefixed to the first
        user message as ``"Instruction: ..."``. If there is no user message, a
        new user message carrying only the instruction is prepended.
        """
        system_text = "\n\n".join(
            message["content"] for message in messages if message["role"] == "system"
        ).strip()
        non_system_messages = [message for message in messages if message["role"] != "system"]
        if not system_text:
            return non_system_messages
        instruction = f"Instruction: {system_text}"
        for index, message in enumerate(non_system_messages):
            if message["role"] == "user":
                non_system_messages[index] = {
                    "role": "user",
                    "content": f"{instruction}\n\n{message['content']}",
                }
                return non_system_messages
        return [{"role": "user", "content": instruction}] + non_system_messages

    def _extract_question_from_content(self, content):
        """Return the question text that follows the last question marker.

        Recognised markers are "Researcher question:" and "Question:". The one
        occurring last in the text wins, so a "Question:" quoted earlier in the
        context cannot hijack a later "Researcher question:" (and vice versa).
        The result is cut at "\\nAnswer:" or "\\n\\nResponse rule:". Returns ""
        when neither marker is present.
        """
        text = str(content or "")
        best_marker, best_pos = None, -1
        for marker in ("Researcher question:", "Question:"):
            pos = text.rfind(marker)
            if pos > best_pos:
                best_marker, best_pos = marker, pos
        if best_marker is None:
            return ""
        question = text[best_pos + len(best_marker):]
        question = question.split("\nAnswer:", 1)[0]
        question = question.split("\n\nResponse rule:", 1)[0]
        return question.strip()

    @staticmethod
    def _mentions_any(text, terms):
        """Return True if any of ``terms`` occurs in ``text`` as a whole word.

        A term must not be glued to other ASCII letters/digits on either side,
        though a trailing plural "s" is allowed. So "ci" matches "95% ci" but
        not "association", and "table" matches "tables" but not "vegetables".
        ``text`` is expected to be lower-cased already.
        """
        if not terms:
            return False
        alternatives = "|".join(re.escape(term.strip()) for term in terms)
        pattern = rf"(?<![a-z0-9])(?:{alternatives})s?(?![a-z0-9])"
        return re.search(pattern, text) is not None

    def _response_rule(self, question, context=""):
        """Build the "Response rule:" text that is appended after the question.

        The rule is chosen with keyword heuristics on the question:

        * relationship wording without a direct-number request gives a
          restate-the-finding rule;
        * digits or numeric wording give an extract-the-value rule;
        * anything else gives a concise direct-answer rule.

        A source-priority sentence is prepended when ``context`` holds both
        published-paper excerpts and variable/table metadata and the question is
        paper-style. Keywords are matched as whole words (see ``_mentions_any``).

        Args:
            question: The user question; converted with ``str``.
            context: Text that is searched for the section headers identifying
                which sources are present.

        Returns:
            The rule text, always starting with ``"Response rule: "``.
        """
        question_text = str(question or "")
        context_text = str(context or "")
        lowered = question_text.lower()
        relationship_terms = (
            "relationship", "association", "associated", "correlation",
            "linked", "impact", "effect", "finding", "according to the study",
        )
        direct_numeric_terms = (
            "odds ratio", "percentage", "percent", "proportion", "count", "number",
            "value", "values", "cutoff", "cut-off", "coefficient", "confidence interval",
            "ci", "p-value", "sample size", "mean", "median", "coded", "code",
            "classified", "unit",
        )
        paper_terms = (
            "paper", "study", "article", "published", "according to the study",
            "according to the paper", "systematic review", "meta-analysis",
        )
        metadata_terms = (
            "variable", "table", "dataset", "data table", "field", "column", "label",
            "coded", "code", "non-missing", "records", "entities", "rows",
        )
        context_has_papers = "## Relevant Published Papers" in context_text or "Source: full-text PDF" in context_text
        context_has_metadata = "## Relevant Variables" in context_text or "## Relevant Tables" in context_text
        asks_paper_style = self._mentions_any(lowered, paper_terms) and not self._mentions_any(lowered, metadata_terms)
        source_priority_rule = ""
        if context_has_papers and context_has_metadata and asks_paper_style:
            source_priority_rule = (
                "For published-paper questions, answer from the Relevant Published Papers/full-text PDF text. "
                "Ignore Relevant Variables and Relevant Tables metadata unless the question explicitly asks for variables, tables, columns, labels, coding, or non-missing records. "
            )
        presence_rule = (
            "Positive evidence has priority over abstention. "
            "If any sentence, table row, header-expanded row, or labelled excerpt contains the requested value, relationship, or finding, answer it directly; do not abstain. "
            "Before abstaining, check repeated headers, row labels, timepoints, outcomes, comparison groups, paper excerpts, and adjacent context for an exact match. "
            "Do not use the abstention answer when a matching answer is present under the requested label, source, study, or outcome. "
        )
        abstention_rule = (
            "Only if the exact requested value, comparison, timepoint, work package, criterion, source, funding body, or finding is not present after that check, "
            "answer exactly: \"The provided context does not contain the requested information.\" and stop. "
            "Do not add partial related information after abstaining. "
            "Do not infer, calculate, or substitute a nearby value from another outcome, age, time period, work package, exposure, table row, cited study, or metadata section. "
        )
        anchor_match_rule = (
            "Wrong-neighbor rule: identify the required anchors in the question, including outcome, measure, subgroup, cohort, model, timepoint, emotion, frequency, exposure, and comparison groups. "
            "Use only evidence where all requested anchors match. "
            "Reject nearby rows or sentences if they are about a different outcome, measure, group, model, timepoint, emotion, frequency, exposure, or comparison, even when their numbers look plausible. "
            "If a matching row contains multiple requested groups, return all requested values from that row; do not omit one group. "
        )
        asks_for_relationship = self._mentions_any(lowered, relationship_terms)
        # A standalone " or " marks a choose-between question, handled like a
        # direct value request.
        asks_for_direct_number = self._mentions_any(lowered, direct_numeric_terms) or " or " in f" {lowered} "
        if asks_for_relationship and not asks_for_direct_number:
            return (
                "Response rule: " + source_priority_rule + presence_rule + anchor_match_rule +
                "State the requested relationship or finding in words using only the context. "
                "Match the requested exposure, outcome, population, cited study, and timepoint exactly. "
                "If the exact relationship statement is present, restate it rather than abstaining, even if the context also contains adjacent unrelated findings. "
                "Keep numbers that identify requested groups or timepoints, but do not add model coefficients, "
                "confidence intervals, p-values, table/figure/page references, variable names, or adjacent numeric details unless the question asks for them. "
                + abstention_rule
            )
        numeric_terms = (
            "odds ratio", "percentage", "percent", "proportion", "count", "number",
            "value", "values", "cutoff", "cut-off", "coefficient", "confidence interval",
            "ci", "p-value", "sample size", "mean", "median", "bmi", "z-score",
            "quintile", "buffer", "year", "month", "date", "coded", "code",
            "classified", "unit",
        )
        is_numeric_like = any(char.isdigit() for char in question_text) or self._mentions_any(
            lowered, numeric_terms
        )
        if is_numeric_like:
            return (
                "Response rule: " + source_priority_rule + presence_rule + anchor_match_rule +
                "Extract only the value(s) requested by the question. "
                "Use nearby headers and labels to choose the matching column, outcome, timepoint, comparison group, or row; do not default to the first number in a row. "
                "Prefer the value whose row and column labels both match the question. If that matching value is present, return it rather than abstaining. "
                "Match the requested age, time period, subgroup, exposure, outcome, and direction exactly before copying a number. "
                "Copy requested numbers, ranges, units, comparison groups, and categories exactly from the context. "
                "If the question asks for multiple groups or contrasts two reported values, include each requested group. "
                "Do not include adjacent rows, unrelated categories, p-values, confidence intervals, variable names, "
                "table/figure/page references, or explanations unless the question asks for them. "
                + abstention_rule
            )
        return (
            "Response rule: " + source_priority_rule +
            "Answer directly and concisely using only the context. "
            "Match the requested paper, study aim, source, criterion, work package, population, exposure, and outcome exactly. "
            "Do not add table, figure, page, row, variable, or section references unless the question asks for them. "
            + abstention_rule
        )

    def _append_response_rule_to_last_user(self, messages):
        """Append a response rule to the last user message, if it has a question.

        Only the last user message is considered. It is left unchanged when it
        already contains "Response rule:" or when no question marker is found.
        Returns a new list; message dicts that are not modified are shared with
        the input.
        """
        updated = list(messages)
        for index in range(len(updated) - 1, -1, -1):
            message = updated[index]
            if message.get("role") != "user":
                continue
            content = message.get("content", "")
            if "Response rule:" not in content:
                question = self._extract_question_from_content(content)
                if question:
                    updated[index] = {
                        "role": "user",
                        "content": f"{content}\n\n{self._response_rule(question, content)}",
                    }
            return updated
        return updated

    def _build_messages(self, inputs, params):
        """Convert the request ``inputs`` into a list of chat messages.

        Accepted shapes, checked in order:

        * a list of ``{"role", "content"}`` dicts;
        * a dict with ``"messages"``, handled like the list form;
        * a dict with ``"context"`` and ``"question"``; an optional
          ``"system"`` value replaces ``DEFAULT_INSTRUCTION``;
        * anything else, treated as a raw prompt: a dict's ``"prompt"`` value,
          otherwise ``inputs`` itself (``None`` counts as an empty prompt).

        System/instruction text is sent as a ``system`` message only when
        ``params["use_system_role"]`` is truthy. Otherwise it is prefixed to the
        first user turn as ``"Instruction: ..."``. A response rule follows the
        question in the context/question form; in the other forms it is appended
        when a question marker is found in the last user turn.
        """
        use_system_role = bool(params.get("use_system_role", False))
        if isinstance(inputs, list):
            messages = self._coerce_messages(inputs)
            if not messages:
                return [{"role": "user", "content": ""}]
            if use_system_role:
                return self._append_response_rule_to_last_user(messages)
            return self._append_response_rule_to_last_user(self._merge_system_into_first_user(messages))
        if isinstance(inputs, dict) and "messages" in inputs:
            return self._build_messages(inputs["messages"], params)
        instruction = DEFAULT_INSTRUCTION
        if isinstance(inputs, dict) and inputs.get("system"):
            instruction = str(inputs["system"])
        if isinstance(inputs, dict) and "context" in inputs and "question" in inputs:
            question = str(inputs["question"])
            context = str(inputs["context"])
            response_rule = self._response_rule(question, context)
            if use_system_role:
                return [
                    {"role": "system", "content": instruction},
                    {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {question}\n\n{response_rule}"},
                ]
            return [
                {
                    "role": "user",
                    "content": f"Instruction: {instruction}\n\nContext:\n{context}\n\nQuestion: {question}\n\n{response_rule}",
                }
            ]
        raw_prompt = inputs.get("prompt", "") if isinstance(inputs, dict) else inputs
        # Avoid sending the literal text "None" as the prompt.
        prompt_text = "" if raw_prompt is None else str(raw_prompt)
        if use_system_role:
            return self._append_response_rule_to_last_user([
                {"role": "system", "content": instruction},
                {"role": "user", "content": prompt_text},
            ])
        return self._append_response_rule_to_last_user([
            {"role": "user", "content": f"Instruction: {instruction}\n\n{prompt_text}"},
        ])

    def __call__(self, data):
        """Run one generation request.

        Args:
            data: Request payload. ``data["inputs"]`` goes to
                ``_build_messages``. ``data["parameters"]`` (optional) may set:

                * ``max_new_tokens``: clamped to 1..512, default 64;
                * ``max_input_tokens``: clamped to 512..8192, default 4096;
                * ``do_sample``, ``temperature``, ``top_p``,
                  ``repetition_penalty``, ``no_repeat_ngram_size``;
                * ``use_system_role`` and ``debug``.

        Returns:
            ``{"generated_text": str}``. With ``debug``, the dict also contains
            the rendered ``prompt``, the ``messages``, the number of
            ``input_tokens`` fed to the model, and whether the prompt was
            ``truncated``.
        """
        inputs = data.get("inputs", "")
        params = data.get("parameters", {}) or {}
        # Clamp to [1, 512]: a non-positive generation budget is invalid.
        max_new_tokens = max(1, min(int(params.get("max_new_tokens", 64)), 512))
        max_input_tokens = max(512, min(int(params.get("max_input_tokens", 4096)), 8192))
        do_sample = bool(params.get("do_sample", False))
        temperature = float(params.get("temperature", 0.7 if do_sample else 0.0))
        top_p = float(params.get("top_p", 1.0))
        repetition_penalty = float(params.get("repetition_penalty", 1.0))
        no_repeat_ngram_size = int(params.get("no_repeat_ngram_size", 0))
        debug = bool(params.get("debug", False))

        messages = self._build_messages(inputs, params)
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # Tokenize the full prompt first so `truncated` reflects the real
        # length; an exactly-fitting prompt is not reported as truncated. Then
        # keep only the last `max_input_tokens` tokens, as left-side truncation
        # would, so the question and response rule at the end survive.
        # NOTE(review): dropping the head presumably also drops the chat-template
        # prefix (special/header tokens) and the instruction; consider trimming
        # the context text instead.
        enc = self.tokenizer(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )
        truncated = enc["input_ids"].shape[-1] > max_input_tokens
        if truncated:
            enc = {key: value[..., -max_input_tokens:] for key, value in enc.items()}
        if torch.cuda.is_available():
            enc = {key: value.to(self.model.device) for key, value in enc.items()}
        eos_token_id = getattr(self.model.config, "eos_token_id", None)
        if eos_token_id is None:
            eos_token_id = self.tokenizer.eos_token_id
        generate_kwargs = dict(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            repetition_penalty=repetition_penalty,
            renormalize_logits=True,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=eos_token_id,
        )
        if do_sample:
            # Sampling needs a strictly positive temperature; floor it.
            generate_kwargs["temperature"] = max(temperature, 1e-5)
            generate_kwargs["top_p"] = top_p
        if no_repeat_ngram_size > 0:
            generate_kwargs["no_repeat_ngram_size"] = no_repeat_ngram_size
        with torch.no_grad():
            out = self.model.generate(**generate_kwargs)
        prompt_length = enc["input_ids"].shape[-1]
        # Decode only the newly generated tokens that follow the prompt.
        generated_ids = out[0][prompt_length:]
        text = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
        response = {"generated_text": text}
        if debug:
            response["prompt"] = prompt
            response["messages"] = messages
            response["input_tokens"] = int(prompt_length)
            response["truncated"] = bool(truncated)
        return response