# Patch summary (text before the first "diff --git" header is commentary and
# is not part of any hunk).
#
# Purpose: change how fix_tokenizer.py decides the "tokenizer_class" entry of
# the tokenizer config dict `cfg` before it is written back to `cfg_path`.
#
# Hunk 1: adds `import transformers` and `import inspect`.
#
# Hunk 2:
#   * Before: cfg["tokenizer_class"] was always assigned from `typ`
#     ("fast" / "sentencepiece" / "bpe" / anything else), and a message was
#     printed only when `orig_cls` was one of the two names in `bad_classes`.
#   * After: VALID_CLASSES is built from every class exposed by the
#     `transformers` package whose name contains "Tokenizer". If `orig_cls`
#     is non-empty, is in VALID_CLASSES and is not in BAD_CLASSES, `cfg` is
#     not modified by this section and a "skip override" message is printed.
#     Otherwise cfg["tokenizer_class"] is set from the FALLBACK table keyed
#     by `typ` (default "PreTrainedTokenizerFast"), with one message for an
#     overridden class and another for a missing one.
#
# `typ`, `orig_cls`, `cfg` and `cfg_path` are defined in parts of the file
# that these hunks do not show.
#
# NOTE(review): inspect.getmembers() calls getattr() for every name returned
# by dir(transformers). transformers presumably exposes its public names
# lazily, so this may import a large part of the package (slow start-up, and
# possibly an error if some optional submodule fails to import) - confirm.
# Testing only the one name that matters, e.g.
# inspect.isclass(getattr(transformers, orig_cls, None)), would avoid that.
diff --git a/NV A100 Patched 镜像合并/fix_tokenizer.py b/NV A100 Patched 镜像合并/fix_tokenizer.py
index 83f3e14..2981d52 100644
--- a/NV A100 Patched 镜像合并/fix_tokenizer.py
+++ b/NV A100 Patched 镜像合并/fix_tokenizer.py
@@ -1,6 +1,8 @@
 import os
 import shutil
 import json
+import transformers
+import inspect
 from detect_tokenizer import detect
 
 MODEL_DIR = os.environ.get("MODEL_DIR", "/model")
@@ -33,18 +35,28 @@ if os.path.exists(cfg_path):
 else:
     cfg = {}
 
-if typ == "fast":
-    cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
-elif typ == "sentencepiece":
-    cfg["tokenizer_class"] = "LlamaTokenizer"
-elif typ == "bpe":
-    cfg["tokenizer_class"] = "GPT2TokenizerFast"
-else:
-    cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
+VALID_CLASSES = {
+    name for name, obj in inspect.getmembers(transformers)
+    if inspect.isclass(obj) and "Tokenizer" in name
+}
 
-bad_classes = ["TokenizersBackend", "TiktokenTokenizer"]
-if orig_cls in bad_classes:
-    print(f"[fix] override bad tokenizer_class: {orig_cls} → {cfg['tokenizer_class']}")
+BAD_CLASSES = {"TokenizersBackend", "TiktokenTokenizer"}
+
+FALLBACK = {
+    "fast": "PreTrainedTokenizerFast",
+    "sentencepiece": "LlamaTokenizer",
+    "bpe": "GPT2TokenizerFast",
+}
+
+if orig_cls and orig_cls in VALID_CLASSES and orig_cls not in BAD_CLASSES:
+    print(f"[fix_tokenizer] tokenizer_class '{orig_cls}' is valid, skip override")
+else:
+    fallback = FALLBACK.get(typ, "PreTrainedTokenizerFast")
+    if orig_cls:
+        print(f"[fix] override bad tokenizer_class: {orig_cls} → {fallback}")
+    else:
+        print(f"[fix] tokenizer_class missing, set to: {fallback}")
+    cfg["tokenizer_class"] = fallback
 
 with open(cfg_path, "w") as f:
     json.dump(cfg, f)