Files
GENERator-eukaryote-1.2b-base/modeling_generator.py

237 lines
9.7 KiB
Python
Raw Permalink Normal View History

"""
GENERator with bp-level (base-pair) generation and scoring.
GENERatorForCausalLM.generate() adds base-by-base token selection to the
standard HF generate() pipeline via a LogitsProcessor. No private generation
methods are overridden, so the code does not depend on transformers' internal generation APIs.
"""
import torch
import torch.nn.functional as F
from transformers import LlamaForCausalLM, LogitsProcessor, LogitsProcessorList
from typing import Union
# Nucleotide <-> integer index used by the per-base (bp) lookup tables.
# The four real bases occupy 0-3 in "ATCG" order; "N" (unknown base) maps to -1
# and is not part of the 4-letter alphabet used for k-mer lookups.
BASE_TO_IDX = {base: idx for idx, base in enumerate("ATCG")}
BASE_TO_IDX["N"] = -1
IDX_TO_BASE = {idx: base for base, idx in BASE_TO_IDX.items()}
class _BPLogitsProcessor(LogitsProcessor):
    """Select the next token base-by-base from per-position marginal probabilities.

    At each step the scores of the k-mer tokens are softmax-normalised,
    marginalised into a [B, k, 4] per-position base distribution, and one
    base is chosen per position (sampled if ``do_sample``, else argmax).
    The k-mer spelled by those bases is returned as a one-hot score row
    (0.0 for that token, -inf everywhere else). The downstream argmax or
    multinomial therefore always picks it. transformers' temperature, top-k
    and top-p warpers leave such a row unchanged.

    Only processors that run *before* this one shape the marginals. In
    transformers' generate(), the built-in temperature / top-k / top-p
    warpers are applied after user-supplied logits processors. To make them
    influence base selection, place equivalent warpers ahead of this
    processor in the list passed to generate().

    Args:
        kmer_ids: [num_kmers] token ids of the k-mer tokens.
        bp_base_index: [k, num_kmers] base index of each k-mer at each
            position; must contain only values 0-3.
        flat_idx_to_token_id: maps the flat index sum(base_i * 4**(k-1-i))
            to a token id.
        bp_powers: [k] place values used to build the flat index.
        k: k-mer length.
        do_sample: sample bases from the marginals instead of taking argmax.
    """

    def __init__(self, kmer_ids, bp_base_index, flat_idx_to_token_id, bp_powers, k, do_sample):
        self.kmer_ids = kmer_ids
        self.bp_base_index = bp_base_index
        self.flat_idx_to_token_id = flat_idx_to_token_id
        self.bp_powers = bp_powers
        self.k = k
        self.do_sample = do_sample
        # Built once per generate() call instead of rebuilding k*4 boolean
        # masks at every decoding step:
        # base_one_hot[pos, j, nt] == 1.0 iff k-mer j has base nt at position pos.
        self.base_one_hot = F.one_hot(bp_base_index, num_classes=4).float()  # [k, num_kmers, 4]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        batch = scores.shape[0]
        kmer_probs = F.softmax(scores[:, self.kmer_ids].float(), dim=-1)  # [B, num_kmers]
        # Marginalise to per-base probabilities: sum the probability of every
        # k-mer carrying base nt at position pos.
        bp_probs = torch.einsum("bj,pjn->bpn", kmer_probs, self.base_one_hot)  # [B, k, 4]
        if self.do_sample:
            base_indices = torch.multinomial(bp_probs.reshape(-1, 4), 1).view(batch, self.k)
        else:
            base_indices = bp_probs.argmax(dim=-1)  # [B, k]
        flat_idx = (base_indices * self.bp_powers).sum(dim=-1)  # [B]
        selected = self.flat_idx_to_token_id[flat_idx]  # [B]
        # One-hot: both argmax and multinomial land on the bp-selected token.
        new_scores = torch.full_like(scores, float("-inf"))
        new_scores.scatter_(1, selected.unsqueeze(1), 0.0)
        return new_scores
class GENERatorForCausalLM(LlamaForCausalLM):
    """LlamaForCausalLM with bp-level (base-pair) generation and scoring.

    Inherits all standard functionality (forward, etc.) and overrides
    generate() so that each k-mer token is assembled base-by-base from
    per-position marginal distributions. score_sequence() scores DNA at
    single-base resolution.

    from_pretrained() tries to load the matching tokenizer and call
    setup_tokenizer() automatically. If that fails, call
    setup_tokenizer(tokenizer) yourself before generate() / score_sequence().
    """

    # Hub / caching arguments of from_pretrained() that are forwarded to the
    # tokenizer load, so it is fetched from the same place as the weights.
    _TOKENIZER_FORWARDED_KWARGS = (
        "cache_dir",
        "force_download",
        "local_files_only",
        "proxies",
        "revision",
        "token",
    )

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Load the model and, when possible, attach its tokenizer.

        Tokenizer loading is best-effort. Any failure is reported with a
        print and the model is still returned, so setup_tokenizer() can be
        called manually.
        """
        model = super().from_pretrained(*args, **kwargs)
        model_path = args[0] if args else kwargs.get("pretrained_model_name_or_path")
        if model_path:
            # Reuse revision / cache / auth / offline settings; otherwise a
            # private or offline checkpoint loads weights but not its tokenizer.
            hub_kwargs = {
                key: kwargs[key] for key in cls._TOKENIZER_FORWARDED_KWARGS if key in kwargs
            }
            try:
                from transformers import AutoTokenizer

                tokenizer = AutoTokenizer.from_pretrained(
                    model_path, trust_remote_code=True, **hub_kwargs
                )
                model.setup_tokenizer(tokenizer)
                print("Tokenizer automatically loaded and configured for bp-level scoring")
            except Exception as e:
                print(f"Could not auto-load tokenizer: {e}")
                print(" Call model.setup_tokenizer(tokenizer) manually if needed")
        return model

    def setup_tokenizer(self, tokenizer):
        """Cache the tokenizer and build the lookup tables used for bp-level operations.

        Registers non-persistent buffers on the model's current device:
          _kmer_ids:             [num_kmers] ids of the pure-ATCG k-mer tokens, sorted by id
          _bp_base_index:        [k, num_kmers] base index (0-3) of k-mer j at position pos
          _bp_powers:            [k] place values 4**(k-1), ..., 4, 1
          _flat_idx_to_token_id: [4**k] token id for each base combination
                                 (-1 where the vocab has no such k-mer)

        Args:
            tokenizer: DNA k-mer tokenizer exposing ``k`` and ``vocab`` (k-mer -> id).
        """
        self.tokenizer = tokenizer
        k = tokenizer.k
        self.k = k
        device = next(self.parameters()).device

        # Pure-ATCG k-mers in token-id order. Special tokens and entries with
        # other characters (e.g. N) are skipped.
        kmer_items = sorted(
            (
                (kmer, tid)
                for kmer, tid in tokenizer.vocab.items()
                if len(kmer) == k and all(b in "ATCG" for b in kmer)
            ),
            key=lambda item: item[1],
        )
        num_kmers = len(kmer_items)

        kmer_ids = torch.tensor([tid for _, tid in kmer_items], dtype=torch.long)
        self.register_buffer("_kmer_ids", kmer_ids.to(device), persistent=False)

        # bp_base_index[pos, j] = base index (0-3) of kmer j at position pos
        bp_base_index = torch.zeros(k, num_kmers, dtype=torch.long)
        for j, (kmer, _) in enumerate(kmer_items):
            for pos, base in enumerate(kmer):
                bp_base_index[pos, j] = BASE_TO_IDX[base]
        self.register_buffer("_bp_base_index", bp_base_index.to(device), persistent=False)

        bp_powers = torch.tensor([4 ** i for i in range(k - 1, -1, -1)], dtype=torch.long)
        self.register_buffer("_bp_powers", bp_powers.to(device), persistent=False)

        # flat index -> token id (flat index = sum base_idx[i] * 4^(k-1-i)).
        # Sized by the number of base combinations (4**k), not num_kmers:
        # flat indices reach 4**k - 1 even if the vocab lacks some k-mers.
        # Filled on CPU to avoid one device write per entry.
        num_combos = 4 ** k
        flat_to_tid = torch.full((num_combos,), -1, dtype=torch.long)
        for kmer, tid in kmer_items:
            flat_idx = sum(BASE_TO_IDX[c] * 4 ** (k - 1 - i) for i, c in enumerate(kmer))
            flat_to_tid[flat_idx] = tid
        self.register_buffer("_flat_idx_to_token_id", flat_to_tid.to(device), persistent=False)
        # generate() can emit every base combination only if all 4**k k-mers exist.
        self._bp_vocab_complete = num_kmers == num_combos

    def compute_bp_probs(self, logits):
        """Compute per-base marginal probabilities from token logits.

        The softmax is taken over the k-mer tokens only. Each position's
        base probability is the total probability of the k-mers carrying
        that base there.

        Args:
            logits: [B, V] or [B, L, V]
        Returns:
            bp_probs: [B, k, 4] or [B, L, k, 4] (float32)
        """
        squeeze = logits.dim() == 2
        if squeeze:
            logits = logits.unsqueeze(1)
        kmer_probs = F.softmax(logits[:, :, self._kmer_ids].float(), dim=-1)  # [B, L, num_kmers]
        # base_one_hot[pos, j, nt] == 1 iff k-mer j has base nt at position pos.
        # Built per call (not cached as a buffer) so model.half() cannot change its dtype.
        base_one_hot = F.one_hot(self._bp_base_index, num_classes=4).to(kmer_probs.dtype)
        bp_probs = torch.einsum("blj,pjn->blpn", kmer_probs, base_one_hot)  # [B, L, k, 4]
        return bp_probs.squeeze(1) if squeeze else bp_probs

    @staticmethod
    def _bp_sampling_warpers(gc, kwargs):
        """Build the temperature / top-k / top-p warpers to run before bp selection.

        Values come from generate() kwargs when given, else from ``gc``, and
        are built in the same order HF applies them.
        """
        from transformers import TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper

        def resolve(name):
            return kwargs[name] if name in kwargs else getattr(gc, name, None)

        temperature, top_k, top_p = resolve("temperature"), resolve("top_k"), resolve("top_p")
        warpers = []
        if temperature is not None and temperature != 1.0:
            warpers.append(TemperatureLogitsWarper(temperature))
        if top_k is not None and top_k != 0:
            warpers.append(TopKLogitsWarper(top_k=top_k))
        if top_p is not None and top_p < 1.0:
            warpers.append(TopPLogitsWarper(top_p=top_p))
        return warpers

    def generate(self, inputs=None, generation_config=None, **kwargs):
        """HF generate() in which each token is selected base-by-base from marginals.

        A bp-level logits processor is appended after any user-supplied
        ``logits_processor``.

        - When sampling, temperature / top_k / top_p (from kwargs, else the
          generation config) are applied to the logits *before* the bp
          processor, so they shape the per-base marginals. transformers' own
          copies of those warpers run afterwards on the processor's one-hot
          output, where they have no effect.
        - Other sampling warpers (e.g. typical_p, min_p) are not reproduced
          and do not influence base selection.
        - Processors that generate() builds from the config ahead of
          user-supplied ones (e.g. repetition_penalty) also shape the
          marginals.

        Output shape and type are identical to generate(). The bp processor
        always selects a k-mer token, so EOS is never chosen. Generation
        stops only via length limits or other stopping criteria.

        Raises:
            RuntimeError: setup_tokenizer() has not been called, or the
                tokenizer vocab lacks some of the 4**k k-mers.
        """
        if not hasattr(self, "_bp_base_index"):
            raise RuntimeError("Call setup_tokenizer(tokenizer) first")
        if not getattr(self, "_bp_vocab_complete", True):
            raise RuntimeError(
                "bp-level generation needs all 4**k k-mers in the tokenizer vocab; "
                "some are missing"
            )
        gc = generation_config or self.generation_config
        do_sample = kwargs.get("do_sample", getattr(gc, "do_sample", False))
        bp_proc = _BPLogitsProcessor(
            kmer_ids=self._kmer_ids,
            bp_base_index=self._bp_base_index,
            flat_idx_to_token_id=self._flat_idx_to_token_id,
            bp_powers=self._bp_powers,
            k=self.k,
            do_sample=do_sample,
        )
        warpers = self._bp_sampling_warpers(gc, kwargs) if do_sample else []
        existing = list(kwargs.pop("logits_processor", None) or [])
        kwargs["logits_processor"] = LogitsProcessorList(existing + warpers + [bp_proc])
        return super().generate(inputs=inputs, generation_config=generation_config, **kwargs)

    @torch.no_grad()
    def score_sequence(self, sequences: Union[str, list]):
        """Score DNA sequence(s) at base resolution.

        Returns per-base probability distributions and the probability of the
        actual base at each position, given all preceding context.

        Args:
            sequences: single DNA string or list of DNA strings (ACGT only;
                "N" is scored with the most likely base's probability)
        Returns:
            (bp_probs, actual_probs) for a single sequence, or
            (list of bp_probs, list of actual_probs) for a batch.
            bp_probs[i]: [seq_len_i, 4] P(base | context) at each position
            actual_probs[i]: [seq_len_i] P(actual base | context)
        Raises:
            RuntimeError: setup_tokenizer() has not been called.
        """
        if not hasattr(self, "tokenizer"):
            raise RuntimeError("Call setup_tokenizer(tokenizer) first")
        is_single = isinstance(sequences, str)
        if is_single:
            sequences = [sequences]
        original_lens = [len(s) for s in sequences]
        # Right-pad to multiple of k with 'A' (matches tokenizer convention)
        padded = [s + "A" * (-len(s) % self.k) for s in sequences]
        # Prepend BOS manually (training format)
        tagged = ["<s>" + s for s in padded]

        # The per-sequence slicing below reads positions 0..num_tokens-1, which
        # is only correct with right padding; enforce it for this call only.
        tokenizer = self.tokenizer
        previous_side = getattr(tokenizer, "padding_side", None)
        if previous_side is not None:
            tokenizer.padding_side = "right"
        try:
            inputs = tokenizer(tagged, return_tensors="pt", padding=True, add_special_tokens=False)
        finally:
            if previous_side is not None:
                tokenizer.padding_side = previous_side

        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        logits = self(input_ids, attention_mask=attention_mask, return_dict=True).logits
        bp_probs_all = self.compute_bp_probs(logits)  # [B, L, k, 4]

        bp_results, actual_results = [], []
        for i, (seq, orig_len, pad_seq) in enumerate(zip(sequences, original_lens, padded)):
            num_tokens = len(pad_seq) // self.k
            # logits[t] predicts token t+1; logits[0] (from <s>) predicts token 1
            seq_bp = bp_probs_all[i, :num_tokens]  # [num_tokens, k, 4]
            seq_bp = seq_bp.reshape(-1, 4)[:orig_len]  # [orig_len, 4]
            bp_results.append(seq_bp)
            actual_results.append(self._extract_actual_probs(seq_bp, seq))
        if is_single:
            return bp_results[0], actual_results[0]
        return bp_results, actual_results

    def _extract_actual_probs(self, bp_probs: torch.Tensor, sequence: str) -> torch.Tensor:
        """Return P(observed base) per position of ``sequence``.

        For "N" the probability of the most likely base is used instead.
        Vectorized with gather/where rather than one tensor write per base.
        """
        bp_probs = bp_probs[: len(sequence)]
        base_idx = torch.tensor(
            [BASE_TO_IDX[base] for base in sequence], dtype=torch.long, device=bp_probs.device
        )
        # BASE_TO_IDX maps only "N" to a negative index; clamp so gather stays in range.
        observed = bp_probs.gather(1, base_idx.clamp(min=0).unsqueeze(1)).squeeze(1)
        return torch.where(base_idx < 0, bp_probs.max(dim=-1).values, observed)