初始化项目,由ModelHub XC社区提供模型
Model: kojima-lab/molcrawl-molecule-nat-lang-mol-instructions-gpt2-small Source: Original Platform
This commit is contained in:
198
sample_inference.py
Normal file
198
sample_inference.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Sample inference script for molcrawl-molecule-nat-lang-gpt2-small.
|
||||
|
||||
This model is a GPT-2 small (124M params) foundation model pretrained on
|
||||
molecule-related natural language data using a standard GPT-2 BPE tokenizer
|
||||
(vocab_size=50257).
|
||||
|
||||
Key fix over the 20260316 version:
|
||||
- 20260316: Used MinimalTokenizer with Python hash() — non-deterministic,
|
||||
decode() impossible, data/model mismatch.
|
||||
- 20260325: Uses GPT2TokenizerFast (BPE) — fully deterministic, decodable.
|
||||
|
||||
Usage:
|
||||
# From HuggingFace Hub
|
||||
python sample_inference.py
|
||||
|
||||
# From local checkpoint dir
|
||||
MODEL_PATH=/path/to/checkpoint python sample_inference.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
try:
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
except ImportError:
|
||||
print("ERROR: Install dependencies: pip install transformers torch")
|
||||
sys.exit(1)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config
|
||||
# ---------------------------------------------------------------------------
|
||||
# Hub repo id or local checkpoint directory. The MODEL_PATH environment
# variable, when set, takes precedence over the built-in default.
# NOTE(review): confirm the default repo id is the checkpoint this script is
# meant to ship with.
MODEL_PATH = os.getenv(
    "MODEL_PATH",
    "kojima-lab/molcrawl-molecule-nat-lang-gpt2-small",
)

# Use the GPU when CUDA reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Prompts that the text-generation demo continues.
DEMO_TEXTS = [
    "The compound with SMILES CC(=O)O is",
    "This molecule has a molecular weight of",
    "The SMILES CC(=O)Oc1ccccc1C(=O)O represents aspirin, which",
    "In drug discovery, the key property of this compound is",
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TEST 1: Tokenizer determinism (validates 20260316 defect is resolved)
|
||||
# ---------------------------------------------------------------------------
|
||||
def test_tokenizer_determinism(tokenizer):
|
||||
"""
|
||||
20260316 defect: MinimalTokenizer used abs(hash(token)) % 50000 + 2.
|
||||
Python hash() is PYTHONHASHSEED-dependent -> different IDs across processes.
|
||||
20260325 fix: GPT2TokenizerFast (BPE) -> fully deterministic.
|
||||
"""
|
||||
print("\n[TEST 1] Tokenizer Determinism")
|
||||
print("-" * 40)
|
||||
text = "The SMILES CC(=O)Oc1ccccc1C(=O)O represents aspirin."
|
||||
|
||||
calls = [tokenizer.encode(text) for _ in range(5)]
|
||||
all_equal = all(c == calls[0] for c in calls)
|
||||
|
||||
print(f" Input : {text!r}")
|
||||
print(f" IDs : {calls[0][:10]}...")
|
||||
print(f" Deterministic (5 calls identical): {'PASS ✓' if all_equal else 'FAIL ✗'}")
|
||||
print(f" vocab_size : {tokenizer.vocab_size}")
|
||||
print(f" max token ID: {max(calls[0])} (< vocab_size: {max(calls[0]) < tokenizer.vocab_size} ✓)")
|
||||
|
||||
# Compare with 20260316 behaviour (MinimalTokenizer with fixed seed for demo)
|
||||
# When PYTHONHASHSEED varies: abs(hash('aspirin')) % 50000 + 2 will differ.
|
||||
# Demonstrating the class of defect:
|
||||
|
||||
# Simulate two different hash seeds via salt (cannot change PYTHONHASHSEED mid-process)
|
||||
# Instead, show the formula directly
|
||||
tok_str = "aspirin"
|
||||
h1 = abs(hash(tok_str)) % 50000 + 2
|
||||
# A different Python process with different PYTHONHASHSEED would give different h1
|
||||
print(f"\n [Defect demo] MinimalTokenizer hash('aspirin') % 50000 + 2 = {h1}")
|
||||
print(" [Defect demo] This value changes across Python processes (PYTHONHASHSEED=random)")
|
||||
print(f" [Fixed] GPT-2 BPE: 'aspirin' -> {tokenizer.encode('aspirin')} (always)")
|
||||
|
||||
return all_equal
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TEST 2: Round-trip encode → decode
|
||||
# ---------------------------------------------------------------------------
|
||||
def test_round_trip(tokenizer):
|
||||
"""Verify encode → decode produces the original text (impossible with MinimalTokenizer)."""
|
||||
print("\n[TEST 2] Round-trip Encode → Decode")
|
||||
print("-" * 40)
|
||||
texts = [
|
||||
"The SMILES CC(=O)Oc1ccccc1C(=O)O represents aspirin.",
|
||||
"Drug discovery requires understanding molecular properties.",
|
||||
"CC(N)C(=O)O is alanine, an amino acid.",
|
||||
]
|
||||
all_pass = True
|
||||
for text in texts:
|
||||
ids = tokenizer.encode(text)
|
||||
decoded = tokenizer.decode(ids, skip_special_tokens=True)
|
||||
match = text.strip() == decoded.strip()
|
||||
all_pass = all_pass and match
|
||||
status = "PASS ✓" if match else "FAIL ✗"
|
||||
print(f" {status} {text[:50]!r}")
|
||||
if not match:
|
||||
print(f" decoded: {decoded!r}")
|
||||
return all_pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TEST 3: Vocabulary coverage of molecule-specific tokens
|
||||
# ---------------------------------------------------------------------------
|
||||
def test_molecule_tokens(tokenizer):
|
||||
"""Check that molecule-specific strings tokenize to reasonable sequences."""
|
||||
print("\n[TEST 3] Molecule Token Coverage")
|
||||
print("-" * 40)
|
||||
examples = {
|
||||
"CC(=O)O": "acetic acid (SMILES)",
|
||||
"c1ccccc1": "benzene ring (SMILES)",
|
||||
"CC(=O)Oc1ccccc1C(=O)O": "aspirin (SMILES)",
|
||||
"NH2": "amine group",
|
||||
"molecular weight": "NL phrase",
|
||||
"IC50": "pharmacology term",
|
||||
"ADMET": "drug property acronym",
|
||||
}
|
||||
for tok_str, desc in examples.items():
|
||||
ids = tokenizer.encode(tok_str)
|
||||
print(f" {desc:35s} -> {len(ids):2d} tokens {ids[:6]}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TEST 4: Text generation
|
||||
# ---------------------------------------------------------------------------
|
||||
def test_generation(model, tokenizer):
    """Sample and print one continuation for every prompt in DEMO_TEXTS.

    Args:
        model: Causal language model exposing ``eval()`` and ``generate()``.
        tokenizer: Tokenizer that is callable on a prompt (returning tensors
            that can be moved with ``.to(DEVICE)``), and that exposes
            ``decode()`` and ``eos_token_id``.
    """
    print("\n[TEST 4] Text Generation")
    print("-" * 40)
    model.eval()

    # Sampling settings shared by all prompts. The EOS id doubles as the
    # padding id passed to generate().
    sampling_options = {
        "max_new_tokens": 60,
        "do_sample": True,
        "temperature": 0.85,
        "top_p": 0.92,
        "pad_token_id": tokenizer.eos_token_id,
        "repetition_penalty": 1.1,
    }

    for prompt in DEMO_TEXTS:
        encoded = tokenizer(prompt, return_tensors="pt").to(DEVICE)
        # Gradients are not needed for generation.
        with torch.no_grad():
            sequences = model.generate(**encoded, **sampling_options)
        first_sequence = sequences[0]
        generated = tokenizer.decode(first_sequence, skip_special_tokens=True)
        print(f"\n Prompt : {prompt!r}")
        print(f" Output : {generated!r}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main
|
||||
# ---------------------------------------------------------------------------
|
||||
def main():
    """Run the inference demo end to end and report the validation outcome.

    Loads the tokenizer and model from ``MODEL_PATH``, runs the tokenizer
    determinism check, the round-trip check, the molecule-token listing and
    the generation demo, then prints a summary.

    Returns:
        0 when both the determinism and round-trip checks passed,
        1 otherwise, so the caller can use it as a process exit status.
    """
    print("=" * 60)
    print("MolCrawl molecule_nat_lang GPT-2 small — Inference Demo")
    print(f"Model : {MODEL_PATH}")
    print(f"Device: {DEVICE}")
    print("=" * 60)

    # Load
    print("\nLoading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
    print(f" class : {type(tokenizer).__name__}")
    print(f" vocab : {tokenizer.vocab_size}")

    print("Loading model...")
    model = AutoModelForCausalLM.from_pretrained(MODEL_PATH).to(DEVICE)
    model.eval()
    n_params = sum(p.numel() for p in model.parameters())
    print(f" params : {n_params:,}")

    # Run tests. Only the first two produce a pass/fail result; the other
    # two are informational.
    r1 = test_tokenizer_determinism(tokenizer)
    r2 = test_round_trip(tokenizer)
    test_molecule_tokens(tokenizer)
    test_generation(model, tokenizer)
    passed = bool(r1 and r2)

    # Summary
    print("\n" + "=" * 60)
    print("Summary")
    print("=" * 60)
    print(f" Tokenizer determinism : {'PASS ✓' if r1 else 'FAIL ✗'}")
    print(f" Round-trip decode : {'PASS ✓' if r2 else 'FAIL ✗'}")
    print(" Text generation : done")
    if passed:
        print("\n All validation tests PASSED.")
        print(" Tokenizer defect from 20260316 (MinimalTokenizer hash-based) is RESOLVED.")
    else:
        print("\n Some tests FAILED — please check the output above.")
    print("=" * 60)

    return 0 if passed else 1


if __name__ == "__main__":
    # Propagate the validation result as the exit status so that shells and
    # CI jobs can detect a failed run instead of always seeing success.
    sys.exit(main())
|
||||
Reference in New Issue
Block a user