"""Custom inference handler for a Born in Bradford research-assistant model.

Builds a chat prompt (instruction + context + question + a heuristic
"Response rule"), runs greedy or sampled generation, and returns the decoded
completion.
"""

import os
import re

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import AutoPeftModelForCausalLM

DEFAULT_INSTRUCTION = (
    "You are an expert research assistant for the Born in Bradford study. "
    "Use only the provided context. "
    "If the context contains the requested value, relationship, or finding, answer it directly; do not say it is absent. "
    "Only answer exactly: The provided context does not contain the requested information after checking paper excerpts, labels, headings, rows, and adjacent context for an exact match. "
    "For table-like or multi-finding context, choose the answer whose row label, column label, outcome, subgroup, timepoint, model, and measurement all match the question; ignore nearby rows where any anchor differs. "
    "Follow the response rule after the question, because it identifies the exact extraction style needed. "
    "Preserve all numbers, ranges, units, p-values, confidence intervals, dates, "
    "distances, comparison groups, and categories exactly as written in the context. "
    "Do not add page numbers, figure numbers, table positions, p-values, methods, "
    "ranges, or qualifiers unless they appear in the provided context. "
    "Answer only the requested value, relationship, or finding; do not include "
    "adjacent table values, extra categories, variable metadata, or explanatory detail unless asked. "
    "Do not add partial related information after abstaining."
)

# --- Keyword heuristics used by EndpointHandler._response_rule -----------------
# All are matched against the lower-cased question via _mentions_any.

_RELATIONSHIP_TERMS = (
    "relationship",
    "association",
    "associated",
    "correlation",
    "linked",
    "impact",
    "effect",
    "finding",
    "according to the study",
)
_DIRECT_NUMERIC_TERMS = (
    "odds ratio",
    " percentage",
    "percent",
    "proportion",
    "count",
    "number",
    "value",
    "values",
    "cutoff",
    "cut-off",
    "coefficient",
    "confidence interval",
    "ci",
    "p-value",
    "sample size",
    "mean",
    "median",
    "coded",
    "code",
    "classified",
    "unit",
)
_PAPER_TERMS = (
    "paper",
    "study",
    "article",
    "published",
    "according to the study",
    "according to the paper",
    "systematic review",
    "meta-analysis",
)
_METADATA_TERMS = (
    "variable",
    "table",
    "dataset",
    "data table",
    "field",
    "column",
    "label",
    "coded",
    "code",
    "non-missing",
    "records",
    "entities",
    "rows",
)
_NUMERIC_TERMS = (
    "odds ratio",
    "percentage",
    "percent",
    "proportion",
    "count",
    "number",
    "value",
    "values",
    "cutoff",
    "cut-off",
    "coefficient",
    "confidence interval",
    "ci",
    "p-value",
    "sample size",
    "mean",
    "median",
    "bmi",
    "z-score",
    "quintile",
    "buffer",
    "year",
    "month",
    "date",
    "coded",
    "code",
    "classified",
    "unit",
)

# Terms too short for substring matching. "ci" (confidence interval) occurs
# inside "association", "associated", "significant", "specific", ..., which
# used to make every association question look numeric. These are matched as
# standalone words only, optionally pluralised ("95% CI", "CIs").
_WHOLE_WORD_TERMS = frozenset({"ci"})

# --- Fixed sentences of the response rules (text is sent to the model verbatim) ---

_SOURCE_PRIORITY_RULE = (
    "For published-paper questions, answer from the Relevant Published Papers/full-text PDF text. "
    "Ignore Relevant Variables and Relevant Tables metadata unless the question explicitly asks for variables, tables, columns, labels, coding, or non-missing records. "
)
_PRESENCE_RULE = (
    "Positive evidence has priority over abstention. "
    "If any sentence, table row, header-expanded row, or labelled excerpt contains the requested value, relationship, or finding, answer it directly; do not abstain. "
    "Before abstaining, check repeated headers, row labels, timepoints, outcomes, comparison groups, paper excerpts, and adjacent context for an exact match. "
    "Do not use the abstention answer when a matching answer is present under the requested label, source, study, or outcome. "
)
_ABSTENTION_RULE = (
    "Only if the exact requested value, comparison, timepoint, work package, criterion, source, funding body, or finding is not present after that check, "
    "answer exactly: \"The provided context does not contain the requested information.\" and stop. "
    "Do not add partial related information after abstaining. "
    "Do not infer, calculate, or substitute a nearby value from another outcome, age, time period, work package, exposure, table row, cited study, or metadata section. "
)
_ANCHOR_MATCH_RULE = (
    "Wrong-neighbor rule: identify the required anchors in the question, including outcome, measure, subgroup, cohort, model, timepoint, emotion, frequency, exposure, and comparison groups. "
    "Use only evidence where all requested anchors match. "
    "Reject nearby rows or sentences if they are about a different outcome, measure, group, model, timepoint, emotion, frequency, exposure, or comparison, even when their numbers look plausible. "
    "If a matching row contains multiple requested groups, return all requested values from that row; do not omit one group. "
)
_RELATIONSHIP_STYLE = (
    "State the requested relationship or finding in words using only the context. "
    "Match the requested exposure, outcome, population, cited study, and timepoint exactly. "
    "If the exact relationship statement is present, restate it rather than abstaining, even if the context also contains adjacent unrelated findings. "
    "Keep numbers that identify requested groups or timepoints, but do not add model coefficients, "
    "confidence intervals, p-values, table/figure/page references, variable names, or adjacent numeric details unless the question asks for them. "
)
_NUMERIC_STYLE = (
    "Extract only the value(s) requested by the question. "
    "Use nearby headers and labels to choose the matching column, outcome, timepoint, comparison group, or row; do not default to the first number in a row. "
    "Prefer the value whose row and column labels both match the question. If that matching value is present, return it rather than abstaining. "
    "Match the requested age, time period, subgroup, exposure, outcome, and direction exactly before copying a number. "
    "Copy requested numbers, ranges, units, comparison groups, and categories exactly from the context. "
    "If the question asks for multiple groups or contrasts two reported values, include each requested group. "
    "Do not include adjacent rows, unrelated categories, p-values, confidence intervals, variable names, "
    "table/figure/page references, or explanations unless the question asks for them. "
)
_CONCISE_STYLE = (
    "Answer directly and concisely using only the context. "
    "Match the requested paper, study aim, source, criterion, work package, population, exposure, and outcome exactly. "
    "Do not add table, figure, page, row, variable, or section references unless the question asks for them. "
)

# Chat roles accepted from caller-supplied message lists.
_ALLOWED_ROLES = frozenset({"system", "user", "assistant"})

# Question markers searched for in user content, in priority order.
_QUESTION_MARKERS = ("Researcher question:", "Question:")


def _mentions_any(text, terms):
    """Return True if any of *terms* occurs in *text*.

    *text* should already be lower-cased. Terms listed in _WHOLE_WORD_TERMS
    must appear as standalone words (optionally with a trailing "s"); all
    other terms match as plain substrings.
    """
    for term in terms:
        if term in _WHOLE_WORD_TERMS:
            if re.search(r"\b" + re.escape(term) + r"s?\b", text):
                return True
        elif term in text:
            return True
    return False


class EndpointHandler:
    """Inference Endpoints handler: builds the prompt and generates an answer."""

    def __init__(self, path: str = ""):
        """Load the tokenizer and model from *path*.

        If ``adapter_config.json`` exists in the directory, the model is loaded as
        a PEFT adapter via AutoPeftModelForCausalLM; otherwise it is loaded as a
        plain causal LM. fp16 is used on CUDA and fp32 on CPU.

        Args:
            path: Model directory; "/repository" is used when empty.
        """
        model_dir = path or "/repository"
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
        if self.tokenizer.pad_token_id is None:
            # generate() needs a pad id; reuse EOS when the tokenizer defines none.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Over-long prompts lose their beginning, not the question/response rule at the end.
        self.tokenizer.truncation_side = "left"

        use_cuda = torch.cuda.is_available()
        load_kwargs = dict(
            trust_remote_code=True,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            low_cpu_mem_usage=True,
            device_map="auto" if use_cuda else None,
        )
        has_adapter = os.path.exists(os.path.join(model_dir, "adapter_config.json"))
        model_cls = AutoPeftModelForCausalLM if has_adapter else AutoModelForCausalLM
        self.model = model_cls.from_pretrained(model_dir, **load_kwargs)
        self.model.eval()

    def _coerce_messages(self, raw_messages):
        """Normalise caller-supplied chat messages.

        Non-dict entries and roles outside system/user/assistant are dropped.
        Content is converted to ``str``, with ``None`` becoming "".

        Returns:
            A new list of ``{"role": str, "content": str}`` dicts.
        """
        messages = []
        for message in raw_messages:
            if not isinstance(message, dict):
                continue
            role = str(message.get("role", "")).strip()
            if role not in _ALLOWED_ROLES:
                continue
            content = message.get("content", "")
            messages.append({"role": role, "content": "" if content is None else str(content)})
        return messages

    def _merge_system_into_first_user(self, messages):
        """Fold all system messages into the first user message.

        Supports chat templates without a system role. System texts are joined
        with blank lines and prefixed to the first user message as
        "Instruction: ...". If there is no user message, a new user message
        carrying the instruction is prepended.

        Returns:
            A new list containing no system messages.
        """
        system_text = "\n\n".join(
            message["content"] for message in messages if message["role"] == "system"
        ).strip()
        non_system_messages = [message for message in messages if message["role"] != "system"]
        if not system_text:
            return non_system_messages
        if not non_system_messages:
            return [{"role": "user", "content": f"Instruction: {system_text}"}]
        for index, message in enumerate(non_system_messages):
            if message["role"] == "user":
                non_system_messages[index] = {
                    "role": "user",
                    "content": f"Instruction: {system_text}\n\n{message['content']}",
                }
                return non_system_messages
        return [{"role": "user", "content": f"Instruction: {system_text}"}] + non_system_messages

    def _extract_question_from_content(self, content):
        """Return the question text embedded in a user message, or "" if none.

        The first marker from ("Researcher question:", "Question:") present in
        the text is used, so the more specific marker wins when both occur. The
        text after that marker's last occurrence is taken, up to any
        "\\nAnswer:" or "\\n\\nResponse rule:" that follows.
        """
        text = str(content or "")
        marker = next((candidate for candidate in _QUESTION_MARKERS if candidate in text), None)
        if marker is None:
            return ""
        question = text.rsplit(marker, 1)[1]
        question = question.split("\nAnswer:", 1)[0]
        question = question.split("\n\nResponse rule:", 1)[0]
        return question.strip()

    def _response_rule(self, question, context=""):
        """Build the "Response rule: ..." text appended after the question.

        The rule is chosen by keyword heuristics on the question:
        * relationship wording without numeric cues -> prose relationship rule;
        * digits or numeric/coding wording -> exact-extraction rule;
        * otherwise -> concise-answer rule.
        A source-priority sentence is prepended when *context* contains both
        published-paper text and variable/table metadata and the question is
        paper-style (mentions a paper term and no metadata term).

        Args:
            question: The question text (coerced with ``str``).
            context: The context text (coerced with ``str``); only its section
                markers are inspected.

        Returns:
            The rule string, always starting with "Response rule: ".
        """
        question_text = str(question or "")
        context_text = str(context or "")
        lowered = question_text.lower()

        context_has_papers = (
            "## Relevant Published Papers" in context_text
            or "Source: full-text PDF" in context_text
        )
        context_has_metadata = (
            "## Relevant Variables" in context_text or "## Relevant Tables" in context_text
        )
        asks_paper_style = _mentions_any(lowered, _PAPER_TERMS) and not _mentions_any(
            lowered, _METADATA_TERMS
        )
        source_priority_rule = (
            _SOURCE_PRIORITY_RULE
            if context_has_papers and context_has_metadata and asks_paper_style
            else ""
        )

        asks_for_relationship = _mentions_any(lowered, _RELATIONSHIP_TERMS)
        # A question offering alternatives ("X or Y") is treated as an extraction question.
        asks_for_direct_number = (
            _mentions_any(lowered, _DIRECT_NUMERIC_TERMS) or " or " in f" {lowered} "
        )
        if asks_for_relationship and not asks_for_direct_number:
            return (
                "Response rule: "
                + source_priority_rule
                + _PRESENCE_RULE
                + _ANCHOR_MATCH_RULE
                + _RELATIONSHIP_STYLE
                + _ABSTENTION_RULE
            )

        is_numeric_like = any(char.isdigit() for char in question_text) or _mentions_any(
            lowered, _NUMERIC_TERMS
        )
        if is_numeric_like:
            return (
                "Response rule: "
                + source_priority_rule
                + _PRESENCE_RULE
                + _ANCHOR_MATCH_RULE
                + _NUMERIC_STYLE
                + _ABSTENTION_RULE
            )

        return "Response rule: " + source_priority_rule + _CONCISE_STYLE + _ABSTENTION_RULE

    def _append_response_rule_to_last_user(self, messages):
        """Append a response rule to the last user message, if appropriate.

        Nothing is added when that message already contains "Response rule:"
        or when no question marker can be found in it. The whole message
        content doubles as the context for _response_rule.

        Returns:
            A shallow copy of *messages*, with the last user message replaced
            when a rule is appended.
        """
        updated = list(messages)
        for index in range(len(updated) - 1, -1, -1):
            message = updated[index]
            if message.get("role") != "user":
                continue
            content = message.get("content", "")
            if "Response rule:" in content:
                return updated
            question = self._extract_question_from_content(content)
            if question:
                updated[index] = {
                    "role": "user",
                    "content": f"{content}\n\n{self._response_rule(question, content)}",
                }
            return updated
        return updated

    def _build_messages(self, inputs, params):
        """Convert the request ``inputs`` into chat messages.

        Accepted shapes:
        * a list of chat messages;
        * ``{"messages": [...]}``;
        * ``{"context": ..., "question": ..., ["system": ...]}``;
        * ``{"prompt": ..., ["system": ...]}`` or any other value (used via ``str``).

        With ``params["use_system_role"]`` truthy, the instruction is sent as a
        system message; otherwise it is inlined into the user message as
        "Instruction: ...".
        """
        use_system_role = bool(params.get("use_system_role", False))

        if isinstance(inputs, list):
            messages = self._coerce_messages(inputs)
            if not messages:
                return [{"role": "user", "content": ""}]
            if use_system_role:
                return self._append_response_rule_to_last_user(messages)
            return self._append_response_rule_to_last_user(
                self._merge_system_into_first_user(messages)
            )

        if isinstance(inputs, dict) and "messages" in inputs:
            return self._build_messages(inputs["messages"], params)

        instruction = DEFAULT_INSTRUCTION
        if isinstance(inputs, dict) and inputs.get("system"):
            instruction = str(inputs["system"])

        if isinstance(inputs, dict) and "context" in inputs and "question" in inputs:
            question = str(inputs["question"])
            context = str(inputs["context"])
            response_rule = self._response_rule(question, context)
            if use_system_role:
                return [
                    {"role": "system", "content": instruction},
                    {
                        "role": "user",
                        "content": f"Context:\n{context}\n\nQuestion: {question}\n\n{response_rule}",
                    },
                ]
            return [
                {
                    "role": "user",
                    "content": f"Instruction: {instruction}\n\nContext:\n{context}\n\nQuestion: {question}\n\n{response_rule}",
                }
            ]

        prompt_text = str(inputs.get("prompt", "")) if isinstance(inputs, dict) else str(inputs)
        if use_system_role:
            return self._append_response_rule_to_last_user(
                [
                    {"role": "system", "content": instruction},
                    {"role": "user", "content": prompt_text},
                ]
            )
        return self._append_response_rule_to_last_user(
            [{"role": "user", "content": f"Instruction: {instruction}\n\n{prompt_text}"}]
        )

    def __call__(self, data):
        """Run one generation request.

        Args:
            data: Request payload with ``inputs`` (see _build_messages) and an
                optional ``parameters`` dict (max_new_tokens, max_input_tokens,
                do_sample, temperature, top_p, repetition_penalty,
                no_repeat_ngram_size, use_system_role, debug).

        Returns:
            ``{"generated_text": str}``. With ``debug`` truthy it also contains the
            rendered prompt, the messages, the input token count and whether the
            prompt was truncated.
        """
        inputs = data.get("inputs", "")
        if inputs is None:
            # An explicit null would otherwise become the literal prompt "None".
            inputs = ""
        params = data.get("parameters", {}) or {}

        # Clamp budgets: at least one new token, and a bounded prompt window.
        max_new_tokens = max(1, min(int(params.get("max_new_tokens", 64)), 512))
        max_input_tokens = max(512, min(int(params.get("max_input_tokens", 4096)), 8192))
        do_sample = bool(params.get("do_sample", False))
        temperature = float(params.get("temperature", 0.7 if do_sample else 0.0))
        top_p = float(params.get("top_p", 1.0))
        repetition_penalty = float(params.get("repetition_penalty", 1.0))
        no_repeat_ngram_size = int(params.get("no_repeat_ngram_size", 0))
        debug = bool(params.get("debug", False))

        messages = self._build_messages(inputs, params)
        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        # The chat template already inserted any special tokens it needs.
        enc = self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
        # Compare the untruncated length so a prompt of exactly max_input_tokens
        # is not misreported as truncated.
        truncated = enc["input_ids"].shape[-1] > max_input_tokens
        if truncated:
            # Keep the last max_input_tokens tokens, as truncation_side="left" would,
            # so the question, response rule and generation prompt survive.
            enc = {key: value[:, -max_input_tokens:] for key, value in enc.items()}

        if torch.cuda.is_available():
            enc = {key: value.to(self.model.device) for key, value in enc.items()}

        eos_token_id = getattr(self.model.config, "eos_token_id", None)
        if eos_token_id is None:
            eos_token_id = self.tokenizer.eos_token_id

        generate_kwargs = dict(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            repetition_penalty=repetition_penalty,
            renormalize_logits=True,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=eos_token_id,
        )
        if do_sample:
            # A zero temperature is invalid when sampling.
            generate_kwargs["temperature"] = max(temperature, 1e-5)
            generate_kwargs["top_p"] = top_p
        if no_repeat_ngram_size > 0:
            generate_kwargs["no_repeat_ngram_size"] = no_repeat_ngram_size

        with torch.no_grad():
            out = self.model.generate(**generate_kwargs)

        prompt_length = enc["input_ids"].shape[-1]
        # generate() returns prompt + continuation; decode only the continuation.
        generated_ids = out[0][prompt_length:]
        text = self.tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        response = {"generated_text": text}
        if debug:
            response["prompt"] = prompt
            response["messages"] = messages
            response["input_tokens"] = int(prompt_length)
            response["truncated"] = bool(truncated)
        return response