Files
molcrawl-molecule-nat-lang-…/sample_inference.py

199 lines
7.5 KiB
Python
Raw Permalink Normal View History

"""
Sample inference script for molcrawl-molecule-nat-lang-gpt2-small.
This model is a GPT-2 small (124M params) foundation model pretrained on
molecule-related natural language data using a standard GPT-2 BPE tokenizer
(vocab_size=50257).
Key fix over the 20260316 version:
- 20260316: Used MinimalTokenizer with Python hash() -> non-deterministic,
  decode() impossible, data/model mismatch.
- 20260325: Uses GPT2TokenizerFast (BPE) -> fully deterministic, decodable.
Usage:
# From HuggingFace Hub
python sample_inference.py
# From local checkpoint dir
MODEL_PATH=/path/to/checkpoint python sample_inference.py
"""
import os
import sys
try:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
print("ERROR: Install dependencies: pip install transformers torch")
sys.exit(1)
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Model location: the MODEL_PATH environment variable wins; otherwise the
# default repository id below is used.
MODEL_PATH = os.getenv(
    "MODEL_PATH",
    "kojima-lab/molcrawl-molecule-nat-lang-gpt2-small",
)

# Run on the GPU whenever torch reports that CUDA is available.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Prompts used by the text-generation demo.
DEMO_TEXTS = [
    "The compound with SMILES CC(=O)O is",
    "This molecule has a molecular weight of",
    "The SMILES CC(=O)Oc1ccccc1C(=O)O represents aspirin, which",
    "In drug discovery, the key property of this compound is",
]
# ---------------------------------------------------------------------------
# TEST 1: Tokenizer determinism (validates 20260316 defect is resolved)
# ---------------------------------------------------------------------------
def test_tokenizer_determinism(tokenizer):
    """
    Check that repeated encodes of the same text yield identical token IDs.

    20260316 defect: MinimalTokenizer used abs(hash(token)) % 50000 + 2.
    Python hash() is PYTHONHASHSEED-dependent -> different IDs across processes.
    20260325 fix: GPT2TokenizerFast (BPE) -> fully deterministic.

    Args:
        tokenizer: Object exposing ``encode(text)`` (returning a sequence of
            token IDs) and a ``vocab_size`` attribute.

    Returns:
        True when all five encodings are identical, False otherwise. The
        vocabulary-range check is only reported in the printed output; it
        does not influence the return value.
    """
    print("\n[TEST 1] Tokenizer Determinism")
    print("-" * 40)
    text = "The SMILES CC(=O)Oc1ccccc1C(=O)O represents aspirin."

    # Encode the same text five times within this process and compare every
    # result against the first one.
    calls = [tokenizer.encode(text) for _ in range(5)]
    reference = calls[0]
    all_equal = all(c == reference for c in calls)

    print(f" Input : {text!r}")
    print(f" IDs : {reference[:10]}...")
    print(f" Deterministic (5 calls identical): {'PASS ✓' if all_equal else 'FAIL ✗'}")
    print(f" vocab_size : {tokenizer.vocab_size}")

    # Compute the bound check once and let the mark follow the real result;
    # previously a check mark was printed even when the comparison was False.
    max_id = max(reference)
    in_vocab = max_id < tokenizer.vocab_size
    print(f" max token ID: {max_id} (< vocab_size: {in_vocab} {'✓' if in_vocab else '✗'})")

    # Compare with 20260316 behaviour: abs(hash('aspirin')) % 50000 + 2 differs
    # whenever PYTHONHASHSEED varies. PYTHONHASHSEED cannot be changed
    # mid-process, so the formula is simply evaluated once here to show the
    # class of defect.
    tok_str = "aspirin"
    h1 = abs(hash(tok_str)) % 50000 + 2
    # A different Python process with a different PYTHONHASHSEED would give a
    # different h1.
    print(f"\n [Defect demo] MinimalTokenizer hash('aspirin') % 50000 + 2 = {h1}")
    print(" [Defect demo] This value changes across Python processes (PYTHONHASHSEED=random)")
    print(f" [Fixed] GPT-2 BPE: 'aspirin' -> {tokenizer.encode('aspirin')} (always)")
    return all_equal
# ---------------------------------------------------------------------------
# TEST 2: Round-trip encode → decode
# ---------------------------------------------------------------------------
def test_round_trip(tokenizer):
    """
    Encode each sample text, decode it again, and compare with the input.

    Leading and trailing whitespace is ignored in the comparison. One status
    line is printed per sample; on a mismatch the decoded text is printed too.
    (A round trip was impossible with MinimalTokenizer.)

    Returns:
        True when every sample survives the round trip, False otherwise.
    """
    print("\n[TEST 2] Round-trip Encode → Decode")
    print("-" * 40)

    samples = (
        "The SMILES CC(=O)Oc1ccccc1C(=O)O represents aspirin.",
        "Drug discovery requires understanding molecular properties.",
        "CC(N)C(=O)O is alanine, an amino acid.",
    )

    outcomes = []
    for sample in samples:
        restored = tokenizer.decode(
            tokenizer.encode(sample),
            skip_special_tokens=True,
        )
        same = sample.strip() == restored.strip()
        outcomes.append(same)

        if same:
            status = "PASS ✓"
        else:
            status = "FAIL ✗"
        print(f" {status} {sample[:50]!r}")

        if same:
            continue
        # Show what came back so the difference can be inspected.
        print(f" decoded: {restored!r}")

    return all(outcomes)
# ---------------------------------------------------------------------------
# TEST 3: Vocabulary coverage of molecule-specific tokens
# ---------------------------------------------------------------------------
def test_molecule_tokens(tokenizer):
    """
    Print how each molecule-related sample string is tokenized.

    For every sample, one line is printed with its description, the number of
    token IDs produced by ``tokenizer.encode`` and the first six IDs. Nothing
    is returned and no pass/fail verdict is computed.
    """
    print("\n[TEST 3] Molecule Token Coverage")
    print("-" * 40)

    # (sample string, human-readable description) pairs, in display order.
    samples = (
        ("CC(=O)O", "acetic acid (SMILES)"),
        ("c1ccccc1", "benzene ring (SMILES)"),
        ("CC(=O)Oc1ccccc1C(=O)O", "aspirin (SMILES)"),
        ("NH2", "amine group"),
        ("molecular weight", "NL phrase"),
        ("IC50", "pharmacology term"),
        ("ADMET", "drug property acronym"),
    )

    for sample, label in samples:
        token_ids = tokenizer.encode(sample)
        count = len(token_ids)
        preview = token_ids[:6]
        print(f" {label:35s} -> {count:2d} tokens {preview}")
# ---------------------------------------------------------------------------
# TEST 4: Text generation
# ---------------------------------------------------------------------------
def test_generation(model, tokenizer):
    """
    Generate and print a sampled continuation for every prompt in DEMO_TEXTS.

    Args:
        model: Causal language model exposing ``eval()`` and ``generate()``.
        tokenizer: Tokenizer used to build the model inputs and to decode the
            generated IDs.

    Sampling is enabled (``do_sample=True``) and no seed is set here, so the
    printed continuations are not expected to be reproducible across runs.
    Nothing is returned.
    """
    print("\n[TEST 4] Text Generation")
    print("-" * 40)
    model.eval()

    # Place the inputs on the device the model actually lives on, so a model
    # supplied on a device other than the module-level default still works.
    # Fall back to DEVICE for model objects without a ``device`` attribute.
    device = getattr(model, "device", DEVICE)

    for prompt in DEMO_TEXTS:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=60,
                do_sample=True,
                temperature=0.85,
                top_p=0.92,
                # Reuse the EOS id as the padding id for generation.
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
            )
        # out[0] is decoded in full and printed as the output.
        generated = tokenizer.decode(out[0], skip_special_tokens=True)
        print(f"\n Prompt : {prompt!r}")
        print(f" Output : {generated!r}")
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """
    Load the tokenizer and model from MODEL_PATH, run all checks, and print
    a summary.

    Returns:
        0 when both pass/fail checks (tokenizer determinism and round-trip
        decode) succeed, 1 otherwise. The value is intended to be used as
        the process exit status.
    """
    print("=" * 60)
    print("MolCrawl molecule_nat_lang GPT-2 small — Inference Demo")
    print(f"Model : {MODEL_PATH}")
    print(f"Device: {DEVICE}")
    print("=" * 60)

    # Load
    print("\nLoading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    print(f" class : {type(tokenizer).__name__}")
    print(f" vocab : {tokenizer.vocab_size}")
    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f" params : {n_params:,}")

    # Run tests. Only the first two produce a pass/fail verdict; the token
    # coverage and generation steps are informational.
    r1 = test_tokenizer_determinism(tokenizer)
    r2 = test_round_trip(tokenizer)
    test_molecule_tokens(tokenizer)
    test_generation(model, tokenizer)

    # Summary
    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print(f" Tokenizer determinism : {'PASS ✓' if r1 else 'FAIL ✗'}")
    print(f" Round-trip decode : {'PASS ✓' if r2 else 'FAIL ✗'}")
    print(" Text generation : done")
    all_passed = bool(r1 and r2)
    if all_passed:
        print("\n All validation tests PASSED.")
        print(" Tokenizer defect from 20260316 (MinimalTokenizer hash-based) is RESOLVED.")
    else:
        print("\n Some tests FAILED — please check the output above.")
    print("=" * 60)

    # Report failure through the exit status as well, in line with the
    # sys.exit(1) already used when the dependencies cannot be imported.
    return 0 if all_passed else 1


if __name__ == "__main__":
    sys.exit(main())