199 lines
7.5 KiB
Python
199 lines
7.5 KiB
Python
|
|
"""
|
||
|
|
Sample inference script for molcrawl-molecule-nat-lang-gpt2-small.
|
||
|
|
|
||
|
|
This model is a GPT-2 small (124M params) foundation model pretrained on
|
||
|
|
molecule-related natural language data using a standard GPT-2 BPE tokenizer
|
||
|
|
(vocab_size=50257).
|
||
|
|
|
||
|
|
Key fix over the 20260316 version:
|
||
|
|
- 20260316: Used MinimalTokenizer with Python hash() — non-deterministic,
|
||
|
|
decode() impossible, data/model mismatch.
|
||
|
|
- 20260325: Uses GPT2TokenizerFast (BPE) — fully deterministic, decodable.
|
||
|
|
|
||
|
|
Usage:
|
||
|
|
# From HuggingFace Hub
|
||
|
|
python sample_inference.py
|
||
|
|
|
||
|
|
# From local checkpoint dir
|
||
|
|
MODEL_PATH=/path/to/checkpoint python sample_inference.py
|
||
|
|
"""
|
||
|
|
|
||
|
|
import os
|
||
|
|
import sys
|
||
|
|
|
||
|
|
try:
|
||
|
|
import torch
|
||
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||
|
|
except ImportError:
|
||
|
|
print("ERROR: Install dependencies: pip install transformers torch")
|
||
|
|
sys.exit(1)
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Config
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Model location: the MODEL_PATH environment variable wins; otherwise the
# Hub repository id below is used.
MODEL_PATH = os.getenv(
    "MODEL_PATH",
    "kojima-lab/molcrawl-molecule-nat-lang-gpt2-small",
)

# Run on the GPU when torch reports CUDA as available, else on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Prompts fed to the text-generation demo.
DEMO_TEXTS = [
    "The compound with SMILES CC(=O)O is",
    "This molecule has a molecular weight of",
    "The SMILES CC(=O)Oc1ccccc1C(=O)O represents aspirin, which",
    "In drug discovery, the key property of this compound is",
]
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# TEST 1: Tokenizer determinism (validates 20260316 defect is resolved)
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
def test_tokenizer_determinism(tokenizer):
|
||
|
|
"""
|
||
|
|
20260316 defect: MinimalTokenizer used abs(hash(token)) % 50000 + 2.
|
||
|
|
Python hash() is PYTHONHASHSEED-dependent -> different IDs across processes.
|
||
|
|
20260325 fix: GPT2TokenizerFast (BPE) -> fully deterministic.
|
||
|
|
"""
|
||
|
|
print("\n[TEST 1] Tokenizer Determinism")
|
||
|
|
print("-" * 40)
|
||
|
|
text = "The SMILES CC(=O)Oc1ccccc1C(=O)O represents aspirin."
|
||
|
|
|
||
|
|
calls = [tokenizer.encode(text) for _ in range(5)]
|
||
|
|
all_equal = all(c == calls[0] for c in calls)
|
||
|
|
|
||
|
|
print(f" Input : {text!r}")
|
||
|
|
print(f" IDs : {calls[0][:10]}...")
|
||
|
|
print(f" Deterministic (5 calls identical): {'PASS ✓' if all_equal else 'FAIL ✗'}")
|
||
|
|
print(f" vocab_size : {tokenizer.vocab_size}")
|
||
|
|
print(f" max token ID: {max(calls[0])} (< vocab_size: {max(calls[0]) < tokenizer.vocab_size} ✓)")
|
||
|
|
|
||
|
|
# Compare with 20260316 behaviour (MinimalTokenizer with fixed seed for demo)
|
||
|
|
# When PYTHONHASHSEED varies: abs(hash('aspirin')) % 50000 + 2 will differ.
|
||
|
|
# Demonstrating the class of defect:
|
||
|
|
|
||
|
|
# Simulate two different hash seeds via salt (cannot change PYTHONHASHSEED mid-process)
|
||
|
|
# Instead, show the formula directly
|
||
|
|
tok_str = "aspirin"
|
||
|
|
h1 = abs(hash(tok_str)) % 50000 + 2
|
||
|
|
# A different Python process with different PYTHONHASHSEED would give different h1
|
||
|
|
print(f"\n [Defect demo] MinimalTokenizer hash('aspirin') % 50000 + 2 = {h1}")
|
||
|
|
print(" [Defect demo] This value changes across Python processes (PYTHONHASHSEED=random)")
|
||
|
|
print(f" [Fixed] GPT-2 BPE: 'aspirin' -> {tokenizer.encode('aspirin')} (always)")
|
||
|
|
|
||
|
|
return all_equal
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# TEST 2: Round-trip encode → decode
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
def test_round_trip(tokenizer):
|
||
|
|
"""Verify encode → decode produces the original text (impossible with MinimalTokenizer)."""
|
||
|
|
print("\n[TEST 2] Round-trip Encode → Decode")
|
||
|
|
print("-" * 40)
|
||
|
|
texts = [
|
||
|
|
"The SMILES CC(=O)Oc1ccccc1C(=O)O represents aspirin.",
|
||
|
|
"Drug discovery requires understanding molecular properties.",
|
||
|
|
"CC(N)C(=O)O is alanine, an amino acid.",
|
||
|
|
]
|
||
|
|
all_pass = True
|
||
|
|
for text in texts:
|
||
|
|
ids = tokenizer.encode(text)
|
||
|
|
decoded = tokenizer.decode(ids, skip_special_tokens=True)
|
||
|
|
match = text.strip() == decoded.strip()
|
||
|
|
all_pass = all_pass and match
|
||
|
|
status = "PASS ✓" if match else "FAIL ✗"
|
||
|
|
print(f" {status} {text[:50]!r}")
|
||
|
|
if not match:
|
||
|
|
print(f" decoded: {decoded!r}")
|
||
|
|
return all_pass
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# TEST 3: Vocabulary coverage of molecule-specific tokens
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
def test_molecule_tokens(tokenizer):
|
||
|
|
"""Check that molecule-specific strings tokenize to reasonable sequences."""
|
||
|
|
print("\n[TEST 3] Molecule Token Coverage")
|
||
|
|
print("-" * 40)
|
||
|
|
examples = {
|
||
|
|
"CC(=O)O": "acetic acid (SMILES)",
|
||
|
|
"c1ccccc1": "benzene ring (SMILES)",
|
||
|
|
"CC(=O)Oc1ccccc1C(=O)O": "aspirin (SMILES)",
|
||
|
|
"NH2": "amine group",
|
||
|
|
"molecular weight": "NL phrase",
|
||
|
|
"IC50": "pharmacology term",
|
||
|
|
"ADMET": "drug property acronym",
|
||
|
|
}
|
||
|
|
for tok_str, desc in examples.items():
|
||
|
|
ids = tokenizer.encode(tok_str)
|
||
|
|
print(f" {desc:35s} -> {len(ids):2d} tokens {ids[:6]}")
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# TEST 4: Text generation
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
def test_generation(model, tokenizer):
    """
    Sample one continuation for every prompt in DEMO_TEXTS and print it.

    The model is switched to eval mode and generation runs under
    ``torch.no_grad()``. Each prompt is tokenized to PyTorch tensors, moved to
    DEVICE, and passed to ``model.generate`` with sampling enabled.

    Args:
        model: Object providing ``eval()`` and ``generate(**kwargs)``.
        tokenizer: Callable tokenizer that also provides ``decode`` and an
            ``eos_token_id`` attribute.
    """
    print("\n[TEST 4] Text Generation")
    print("-" * 40)
    model.eval()

    # Sampling settings shared by every prompt; the tokenizer's EOS id is
    # passed as the padding id.
    sampling = {
        "max_new_tokens": 60,
        "do_sample": True,
        "temperature": 0.85,
        "top_p": 0.92,
        "pad_token_id": tokenizer.eos_token_id,
        "repetition_penalty": 1.1,
    }

    for prompt in DEMO_TEXTS:
        encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            sequences = model.generate(**encoded, **sampling)
        # Decode only the first returned sequence.
        completion = tokenizer.decode(sequences[0], skip_special_tokens=True)
        print(f"\n Prompt : {prompt!r}")
        print(f" Output : {completion!r}")
|
||
|
|
|
||
|
|
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
# Main
|
||
|
|
# ---------------------------------------------------------------------------
|
||
|
|
def main():
    """
    Load the tokenizer and model from MODEL_PATH, run the four checks, and
    print a summary.

    The determinism and round-trip checks decide the overall outcome; the
    molecule-token and generation steps only print output.

    Returns:
        0 if both the determinism and round-trip checks passed, 1 otherwise,
        so the value can be used directly as the process exit status.
    """
    rule = "=" * 60
    print(rule)
    print("MolCrawl molecule_nat_lang GPT-2 small — Inference Demo")
    print(f"Model : {MODEL_PATH}")
    print(f"Device: {DEVICE}")
    print(rule)

    # Load
    print("\nLoading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    print(f" class : {type(tokenizer).__name__}")
    print(f" vocab : {tokenizer.vocab_size}")

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f" params : {n_params:,}")

    # Run tests
    r1 = test_tokenizer_determinism(tokenizer)
    r2 = test_round_trip(tokenizer)
    test_molecule_tokens(tokenizer)
    test_generation(model, tokenizer)

    # Summary
    passed = bool(r1 and r2)
    print("\n" + rule)
    print("Summary")
    print(rule)
    print(f" Tokenizer determinism : {'PASS ✓' if r1 else 'FAIL ✗'}")
    print(f" Round-trip decode : {'PASS ✓' if r2 else 'FAIL ✗'}")
    print(" Text generation : done")
    if passed:
        print("\n All validation tests PASSED.")
        print(" Tokenizer defect from 20260316 (MinimalTokenizer hash-based) is RESOLVED.")
    else:
        print("\n Some tests FAILED — please check the output above.")
    print(rule)

    # Report failure through the exit status so callers (shell, CI) can
    # detect it; previously the script always exited with status 0.
    return 0 if passed else 1


if __name__ == "__main__":
    sys.exit(main())
|