#!/usr/bin/env python3
"""Detect and repair two known problems in a model's tokenizer_config.json.

1. ``tokenizer_class`` names a class that the installed ``transformers`` does
   not provide (e.g. ``TokenizersBackend``).
2. ``extra_special_tokens`` is stored as a list, while the installed
   ``transformers`` expects a dict.

If either problem is found, the tokenizer files are copied into
``/tmp/fixed_tokenizer/``, the copied ``tokenizer_config.json`` is rewritten
with the fixes, and that directory path is printed to stdout. If nothing needs
fixing, stdout stays empty. All diagnostics go to stderr so that stdout
carries only the path.

Usage: python3 <this script> [MODEL_DIR]   (falls back to $MODEL_DIR, then /model)
"""
import os
import sys
import json
import shutil

MODEL_DIR = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("MODEL_DIR", "/model")
OUT_DIR = "/tmp/fixed_tokenizer"

# Tokenizer-related files copied into OUT_DIR when a fix is needed; names that
# do not exist in MODEL_DIR are skipped. chat_template.jinja / chat_template.json
# hold the chat template when it is not embedded in tokenizer_config.json, and
# added_tokens.json holds extra vocabulary -- omitting them would make the fixed
# directory an incomplete substitute for the original tokenizer.
TOKENIZER_FILES = (
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "added_tokens.json",
    "vocab.json",
    "vocab.txt",
    "merges.txt",
    "tokenizer.model",
    "chat_template.jinja",
    "chat_template.json",
)


def find_bad_tokenizer_class(cfg):
    """Return ``cfg["tokenizer_class"]`` if transformers cannot resolve it, else None.

    Args:
        cfg: parsed tokenizer_config.json.

    Returns:
        The unresolvable class name, or None when the class is absent, empty,
        resolvable, provided by remote code, or cannot be checked.
    """
    tokenizer_class = cfg.get("tokenizer_class")
    if not isinstance(tokenizer_class, str) or not tokenizer_class:
        return None

    # A config whose auto_map declares an AutoTokenizer uses a custom
    # (trust_remote_code) class that is never an attribute of the transformers
    # package. Rewriting it would break loading, and the custom .py files are
    # not copied into OUT_DIR anyway, so leave such configs alone.
    auto_map = cfg.get("auto_map")
    if isinstance(auto_map, dict) and "AutoTokenizer" in auto_map:
        return None

    try:
        import transformers  # imported lazily: heavy, and only needed for this check
    except ImportError:
        print(
            "[fix_tokenizer] transformers not importable; skipping tokenizer_class check",
            file=sys.stderr,
        )
        return None

    if getattr(transformers, tokenizer_class, None) is None:
        return tokenizer_class
    return None


def pick_fallback_class(files):
    """Choose a built-in tokenizer class name based on which tokenizer files exist.

    Args:
        files: collection of file names present in the model directory.

    Returns:
        "PreTrainedTokenizerFast" when tokenizer.json exists, else
        "LlamaTokenizer" when tokenizer.model exists, else "GPT2TokenizerFast"
        when both vocab.json and merges.txt exist, else "PreTrainedTokenizerFast".
    """
    if "tokenizer.json" in files:
        return "PreTrainedTokenizerFast"
    if "tokenizer.model" in files:
        return "LlamaTokenizer"
    if "vocab.json" in files and "merges.txt" in files:
        return "GPT2TokenizerFast"
    return "PreTrainedTokenizerFast"


def extra_tokens_to_dict(tokens):
    """Convert a list-form ``extra_special_tokens`` into a dict mapping each token to itself.

    Entries serialized as objects (``{"content": "<tok>", ...}``) are keyed and
    valued by their ``content`` string, since a dict cannot be a dict key.
    Entries that yield no string are skipped with a warning on stderr.
    Duplicates collapse into a single entry; first-seen order is kept.

    Args:
        tokens: the original list value of ``extra_special_tokens``.

    Returns:
        dict[str, str]
    """
    result = {}
    for token in tokens:
        if isinstance(token, dict):
            token = token.get("content")
        if not isinstance(token, str):
            print(
                f"[fix_tokenizer] skipping unsupported extra_special_tokens entry: {token!r}",
                file=sys.stderr,
            )
            continue
        result[token] = token
    return result


def copy_tokenizer_files(src_dir, dst_dir):
    """Recreate ``dst_dir`` empty and copy every existing TOKENIZER_FILES entry into it."""
    # Start from an empty directory so files left by a previous run (possibly
    # for a different model) cannot leak into the fixed tokenizer.
    shutil.rmtree(dst_dir, ignore_errors=True)
    os.makedirs(dst_dir, exist_ok=True)
    for fname in TOKENIZER_FILES:
        src = os.path.join(src_dir, fname)
        if os.path.exists(src):
            shutil.copy(src, dst_dir)


def main():
    """Check MODEL_DIR's tokenizer_config.json and, if needed, write a fixed copy to OUT_DIR.

    Prints OUT_DIR to stdout only when a fixed copy was written; every other
    message goes to stderr. Returns silently when tokenizer_config.json is
    missing or needs no changes.
    """
    cfg_path = os.path.join(MODEL_DIR, "tokenizer_config.json")
    if not os.path.exists(cfg_path):
        return

    with open(cfg_path, encoding="utf-8") as f:
        cfg = json.load(f)

    fixes = []

    # --- Check 1: tokenizer_class must exist in transformers ---
    tokenizer_class = find_bad_tokenizer_class(cfg)
    if tokenizer_class is not None:
        fixes.append(f"tokenizer_class '{tokenizer_class}' not found in transformers")

    # --- Check 2: extra_special_tokens must not be a list ---
    bad_extra_special_tokens = isinstance(cfg.get("extra_special_tokens"), list)
    if bad_extra_special_tokens:
        fixes.append("extra_special_tokens is a list, expected dict")

    if not fixes:
        return  # nothing to fix

    for reason in fixes:
        print(f"[fix_tokenizer] detected: {reason}", file=sys.stderr)

    copy_tokenizer_files(MODEL_DIR, OUT_DIR)

    # --- Fix 1: replace tokenizer_class with a built-in class ---
    if tokenizer_class is not None:
        fixed_class = pick_fallback_class(set(os.listdir(MODEL_DIR)))
        cfg["tokenizer_class"] = fixed_class
        print(
            f"[fix_tokenizer] tokenizer_class: '{tokenizer_class}' → '{fixed_class}'",
            file=sys.stderr,
        )

    # --- Fix 2: extra_special_tokens list → dict ---
    if bad_extra_special_tokens:
        orig_list = cfg["extra_special_tokens"]
        cfg["extra_special_tokens"] = extra_tokens_to_dict(orig_list)
        print(
            f"[fix_tokenizer] extra_special_tokens: list({len(orig_list)}) → dict",
            file=sys.stderr,
        )

    with open(os.path.join(OUT_DIR, "tokenizer_config.json"), "w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2, ensure_ascii=False)

    print(OUT_DIR)  # the only stdout output; captured by entrypoint.sh


if __name__ == "__main__":
    main()