"""
Sample inference script for molcrawl-molecule-nat-lang-gpt2-small.

This model is a GPT-2 small (124M params) foundation model pretrained on
molecule-related natural language data using a standard GPT-2 BPE tokenizer
(vocab_size=50257).

Key fix over the 20260316 version:
  - 20260316: Used MinimalTokenizer with Python hash() — IDs depend on
    PYTHONHASHSEED (non-deterministic across processes), decode() impossible,
    data/model mismatch.
  - 20260325: Uses GPT2TokenizerFast (BPE) — fully deterministic, decodable.

Usage:
    # From HuggingFace Hub
    python sample_inference.py

    # From local checkpoint dir
    MODEL_PATH=/path/to/checkpoint python sample_inference.py
"""

import os
import subprocess
import sys

try:
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer
except ImportError:
    print("ERROR: Install dependencies: pip install transformers torch")
    sys.exit(1)

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
MODEL_PATH = os.environ.get("MODEL_PATH", "kojima-lab/molcrawl-molecule-nat-lang-gpt2-small")
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

DEMO_TEXTS = [
    "The compound with SMILES CC(=O)O is",
    "This molecule has a molecular weight of",
    "The SMILES CC(=O)Oc1ccccc1C(=O)O represents aspirin, which",
    "In drug discovery, the key property of this compound is",
]


# ---------------------------------------------------------------------------
# TEST 1: Tokenizer determinism (validates 20260316 defect is resolved)
# ---------------------------------------------------------------------------
def _minimal_tokenizer_id(token, hash_seed):
    """Return the ID the 20260316 MinimalTokenizer formula gives *token* under *hash_seed*.

    The formula ``abs(hash(token)) % 50000 + 2`` is evaluated in a fresh
    interpreter with ``PYTHONHASHSEED=hash_seed``.  The str-hash seed is fixed
    when an interpreter starts, so a subprocess is the only way to observe
    the value a different process would compute.

    Returns:
        The integer ID, or None if the subprocess could not be run or its
        output could not be parsed (e.g. in a sandbox that forbids spawning).
    """
    env = dict(os.environ, PYTHONHASHSEED=str(hash_seed))
    code = f"print(abs(hash({token!r})) % 50000 + 2)"
    try:
        proc = subprocess.run(
            [sys.executable, "-c", code],
            env=env,
            capture_output=True,
            text=True,
            check=True,
            timeout=30,
        )
        return int(proc.stdout.strip())
    except (OSError, subprocess.SubprocessError, ValueError):
        return None


def test_tokenizer_determinism(tokenizer):
    """
    Check that encoding is repeatable and that token IDs fit the vocabulary.

    20260316 defect: MinimalTokenizer used abs(hash(token)) % 50000 + 2.
    Python hash() is PYTHONHASHSEED-dependent -> different IDs across processes.
    20260325 fix: GPT2TokenizerFast (BPE) -> fully deterministic.

    Note: repeating the encoding inside one process cannot detect hash-seed
    dependence (the seed is constant for the life of a process), so the
    PASS/FAIL result only covers in-process repeatability.  The cross-process
    defect is demonstrated separately by evaluating the old formula under two
    different PYTHONHASHSEED values.

    Returns:
        True if 5 consecutive encodings of the same text are identical.
    """
    print("\n[TEST 1] Tokenizer Determinism")
    print("-" * 40)

    text = "The SMILES CC(=O)Oc1ccccc1C(=O)O represents aspirin."
    calls = [tokenizer.encode(text) for _ in range(5)]
    all_equal = all(c == calls[0] for c in calls)
    max_id = max(calls[0])
    in_range = max_id < tokenizer.vocab_size

    print(f"  Input : {text!r}")
    print(f"  IDs   : {calls[0][:10]}...")
    print(f"  Deterministic (5 calls identical): {'PASS ✓' if all_equal else 'FAIL ✗'}")
    print(f"  vocab_size  : {tokenizer.vocab_size}")
    # The mark reflects the actual comparison instead of always showing ✓.
    print(f"  max token ID: {max_id} (< vocab_size: {in_range} {'✓' if in_range else '✗'})")

    # Demonstrate the 20260316 defect: same token, two hash seeds -> two IDs.
    tok_str = "aspirin"
    ids_by_seed = {seed: _minimal_tokenizer_id(tok_str, seed) for seed in (1, 2)}
    if None in ids_by_seed.values():
        # Could not spawn interpreters; fall back to this process's value only.
        h1 = abs(hash(tok_str)) % 50000 + 2
        print(f"\n  [Defect demo] MinimalTokenizer hash('aspirin') % 50000 + 2 = {h1}")
        print("  [Defect demo] This value changes across Python processes (PYTHONHASHSEED=random)")
    else:
        print("\n  [Defect demo] MinimalTokenizer abs(hash('aspirin')) % 50000 + 2 in separate processes:")
        for seed, token_id in ids_by_seed.items():
            print(f"  [Defect demo]   PYTHONHASHSEED={seed} -> {token_id}")
        print("  [Defect demo] Same token, different process -> different ID")
    print(f"  [Fixed] GPT-2 BPE: 'aspirin' -> {tokenizer.encode('aspirin')} (always)")

    return all_equal


# ---------------------------------------------------------------------------
# TEST 2: Round-trip encode → decode
# ---------------------------------------------------------------------------
def test_round_trip(tokenizer):
    """Verify encode → decode reproduces the original text (impossible with MinimalTokenizer).

    ``clean_up_tokenization_spaces=False`` is passed so the decoder returns the
    raw BPE output.  With clean-up enabled (the default for some
    tokenizer/transformers versions), the decoder may rewrite spacing around
    punctuation and report a spurious mismatch.

    Returns:
        True if every sample text survives the round trip, ignoring leading
        and trailing whitespace.
    """
    print("\n[TEST 2] Round-trip Encode → Decode")
    print("-" * 40)

    texts = [
        "The SMILES CC(=O)Oc1ccccc1C(=O)O represents aspirin.",
        "Drug discovery requires understanding molecular properties.",
        "CC(N)C(=O)O is alanine, an amino acid.",
    ]
    all_pass = True
    for text in texts:
        ids = tokenizer.encode(text)
        decoded = tokenizer.decode(
            ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )
        match = text.strip() == decoded.strip()
        all_pass = all_pass and match
        status = "PASS ✓" if match else "FAIL ✗"
        print(f"  {status}  {text[:50]!r}")
        if not match:
            print(f"         decoded: {decoded!r}")

    return all_pass
# ---------------------------------------------------------------------------
# TEST 3: Vocabulary coverage of molecule-specific tokens
# ---------------------------------------------------------------------------
def test_molecule_tokens(tokenizer):
    """Print how molecule-specific strings tokenize (informational only, no pass/fail).

    For each example, prints its description, the number of BPE tokens, and
    the first six token IDs.
    """
    print("\n[TEST 3] Molecule Token Coverage")
    print("-" * 40)

    examples = {
        "CC(=O)O": "acetic acid (SMILES)",
        "c1ccccc1": "benzene ring (SMILES)",
        "CC(=O)Oc1ccccc1C(=O)O": "aspirin (SMILES)",
        "NH2": "amine group",
        "molecular weight": "NL phrase",
        "IC50": "pharmacology term",
        "ADMET": "drug property acronym",
    }
    for tok_str, desc in examples.items():
        ids = tokenizer.encode(tok_str)
        print(f"  {desc:35s} -> {len(ids):2d} tokens  {ids[:6]}")


# ---------------------------------------------------------------------------
# TEST 4: Text generation
# ---------------------------------------------------------------------------
def test_generation(model, tokenizer):
    """Generate and print sampled continuations for each prompt in DEMO_TEXTS.

    Uses nucleus sampling (``do_sample=True``) with no fixed RNG seed, so the
    printed continuations differ from run to run.  The decoded output includes
    the prompt itself, because ``out[0]`` contains the input tokens followed by
    the generated ones.
    """
    print("\n[TEST 4] Text Generation")
    print("-" * 40)

    model.eval()
    for prompt in DEMO_TEXTS:
        inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=60,
                do_sample=True,
                temperature=0.85,
                top_p=0.92,
                # GPT-2 has no pad token; reuse EOS to silence the warning.
                pad_token_id=tokenizer.eos_token_id,
                repetition_penalty=1.1,
            )
        generated = tokenizer.decode(out[0], skip_special_tokens=True)
        print(f"\n  Prompt : {prompt!r}")
        print(f"  Output : {generated!r}")


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main():
    """Load the tokenizer and model from MODEL_PATH, run all checks, and print a summary.

    Returns:
        Process exit status: 0 if both the determinism and round-trip checks
        passed, 1 otherwise, so shell/CI callers can detect a failed run.
    """
    print("=" * 60)
    print("MolCrawl molecule_nat_lang GPT-2 small — Inference Demo")
    print(f"Model : {MODEL_PATH}")
    print(f"Device: {DEVICE}")
    print("=" * 60)

    # Load
    print("\nLoading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    print(f"  class    : {type(tokenizer).__name__}")
    print(f"  vocab    : {tokenizer.vocab_size}")

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f"  params   : {n_params:,}")

    # Run tests
    r1 = test_tokenizer_determinism(tokenizer)
    r2 = test_round_trip(tokenizer)
    test_molecule_tokens(tokenizer)
    test_generation(model, tokenizer)

    # Summary
    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print(f"  Tokenizer determinism : {'PASS ✓' if r1 else 'FAIL ✗'}")
    print(f"  Round-trip decode     : {'PASS ✓' if r2 else 'FAIL ✗'}")
    print("  Text generation       : done")

    all_passed = r1 and r2
    if all_passed:
        print("\n  All validation tests PASSED.")
        print("  Tokenizer defect from 20260316 (MinimalTokenizer hash-based) is RESOLVED.")
    else:
        print("\n  Some tests FAILED — please check the output above.")
    print("=" * 60)

    return 0 if all_passed else 1


if __name__ == "__main__":
    # Propagate the validation result as the process exit status.
    sys.exit(main())