diff --git a/NV A100 Patched 镜像合并/fix_tokenizer.py b/NV A100 Patched 镜像合并/fix_tokenizer.py new file mode 100644 index 0000000..83f3e14 --- /dev/null +++ b/NV A100 Patched 镜像合并/fix_tokenizer.py @@ -0,0 +1,52 @@ +import os +import shutil +import json +from detect_tokenizer import detect + +MODEL_DIR = os.environ.get("MODEL_DIR", "/model") +OUT_DIR = os.environ.get("FIX_TOKENIZER_DIR", "/tmp/fixed_tokenizer") + +os.makedirs(OUT_DIR, exist_ok=True) + +def copy_if_exists(name): + src = os.path.join(MODEL_DIR, name) + if os.path.exists(src): + shutil.copy(src, OUT_DIR) + +for f in [ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "vocab.json", + "merges.txt", + "tokenizer.model", +]: + copy_if_exists(f) + +typ, orig_cls = detect(MODEL_DIR) + +cfg_path = os.path.join(OUT_DIR, "tokenizer_config.json") + +if os.path.exists(cfg_path): + with open(cfg_path) as f: + cfg = json.load(f) +else: + cfg = {} + +if typ == "fast": + cfg["tokenizer_class"] = "PreTrainedTokenizerFast" +elif typ == "sentencepiece": + cfg["tokenizer_class"] = "LlamaTokenizer" +elif typ == "bpe": + cfg["tokenizer_class"] = "GPT2TokenizerFast" +else: + cfg["tokenizer_class"] = "PreTrainedTokenizerFast" + +bad_classes = ["TokenizersBackend", "TiktokenTokenizer"] +if orig_cls in bad_classes: + print(f"[fix] override bad tokenizer_class: {orig_cls} → {cfg['tokenizer_class']}") + +with open(cfg_path, "w") as f: + json.dump(cfg, f) + +print(f"[fix_tokenizer] done → {OUT_DIR}") \ No newline at end of file