"""Stage a model's tokenizer files in a scratch directory and repair ``tokenizer_class``.

The script copies the known tokenizer files from ``MODEL_DIR`` into ``OUT_DIR``
and then inspects the ``tokenizer_class`` reported by ``detect``.  If that
class is missing, is not a class exported by ``transformers``, or is listed in
``BAD_CLASSES``, the copied ``tokenizer_config.json`` is rewritten with a
fallback class chosen from ``FALLBACK``.

Environment variables:
    MODEL_DIR: source model directory (default ``/model``).
    FIX_TOKENIZER_DIR: output directory (default ``/tmp/fixed_tokenizer``).
"""
import inspect
import json
import os
import shutil
import sys

import transformers

# ``detect_tokenizer`` is loaded from /opt, so the search path has to be
# extended before the import statement runs.
sys.path.insert(0, '/opt')
from detect_tokenizer import detect  # noqa: E402

MODEL_DIR = os.environ.get("MODEL_DIR", "/model")
OUT_DIR = os.environ.get("FIX_TOKENIZER_DIR", "/tmp/fixed_tokenizer")

# Files copied verbatim from MODEL_DIR when present.
TOKENIZER_FILES = (
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "vocab.json",
    "merges.txt",
    "tokenizer.model",
)

# Class names that are always overridden, even if ``transformers`` exports them.
BAD_CLASSES = {"TokenizersBackend", "TiktokenTokenizer"}

# Replacement class per tokenizer type (the first value returned by ``detect``).
FALLBACK = {
    "fast": "PreTrainedTokenizerFast",
    "sentencepiece": "LlamaTokenizer",
    "bpe": "GPT2TokenizerFast",
}

os.makedirs(OUT_DIR, exist_ok=True)


def copy_if_exists(name):
    """Copy the file ``name`` from ``MODEL_DIR`` into ``OUT_DIR`` if it exists.

    Missing entries are skipped silently.  Only regular files are copied, so
    a directory that happens to share the name does not make ``shutil.copy``
    raise.
    """
    src = os.path.join(MODEL_DIR, name)
    if os.path.isfile(src):
        shutil.copy(src, OUT_DIR)


def _is_usable_tokenizer_class(name):
    """Return True if ``name`` is an acceptable ``tokenizer_class`` value.

    A name is acceptable when it is a string containing ``"Tokenizer"``, is
    not listed in ``BAD_CLASSES``, and resolves to a class on the
    ``transformers`` package.  Only this one attribute is looked up, rather
    than enumerating every member of the package.
    """
    if not isinstance(name, str) or "Tokenizer" not in name:
        return False
    if name in BAD_CLASSES:
        return False
    return inspect.isclass(getattr(transformers, name, None))


for filename in TOKENIZER_FILES:
    copy_if_exists(filename)

# NOTE(review): ``detect`` is defined outside this file; it is used here as
# returning (tokenizer type key for FALLBACK, original class name or a falsy
# value) -- confirm against detect_tokenizer.
typ, orig_cls = detect(MODEL_DIR)

cfg_path = os.path.join(OUT_DIR, "tokenizer_config.json")
if os.path.exists(cfg_path):
    # Read as UTF-8 explicitly instead of relying on the locale default.
    with open(cfg_path, encoding="utf-8") as cfg_file:
        cfg = json.load(cfg_file)
else:
    cfg = {}

if orig_cls and _is_usable_tokenizer_class(orig_cls):
    print(f"[fix_tokenizer] tokenizer_class '{orig_cls}' is valid, skip override")
else:
    fallback = FALLBACK.get(typ, "PreTrainedTokenizerFast")
    if orig_cls:
        print(f"[fix] override bad tokenizer_class: {orig_cls} → {fallback}")
    else:
        print(f"[fix] tokenizer_class missing, set to: {fallback}")
    cfg["tokenizer_class"] = fallback
    # The config is rewritten only when it was actually changed.
    with open(cfg_path, "w", encoding="utf-8") as cfg_file:
        json.dump(cfg, cfg_file)

print(f"[fix_tokenizer] done → {OUT_DIR}")