更新 NV A100 Patched 镜像合并/fix_tokenizer.py
This commit is contained in:
@@ -1,6 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import json
|
import json
|
||||||
|
import transformers
|
||||||
|
import inspect
|
||||||
from detect_tokenizer import detect
|
from detect_tokenizer import detect
|
||||||
|
|
||||||
MODEL_DIR = os.environ.get("MODEL_DIR", "/model")
|
MODEL_DIR = os.environ.get("MODEL_DIR", "/model")
|
||||||
@@ -33,18 +35,28 @@ if os.path.exists(cfg_path):
|
|||||||
else:
|
else:
|
||||||
cfg = {}
|
cfg = {}
|
||||||
|
|
||||||
if typ == "fast":
|
VALID_CLASSES = {
|
||||||
cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
|
name for name, obj in inspect.getmembers(transformers)
|
||||||
elif typ == "sentencepiece":
|
if inspect.isclass(obj) and "Tokenizer" in name
|
||||||
cfg["tokenizer_class"] = "LlamaTokenizer"
|
}
|
||||||
elif typ == "bpe":
|
|
||||||
cfg["tokenizer_class"] = "GPT2TokenizerFast"
|
|
||||||
else:
|
|
||||||
cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
|
|
||||||
|
|
||||||
bad_classes = ["TokenizersBackend", "TiktokenTokenizer"]
|
BAD_CLASSES = {"TokenizersBackend", "TiktokenTokenizer"}
|
||||||
if orig_cls in bad_classes:
|
|
||||||
print(f"[fix] override bad tokenizer_class: {orig_cls} → {cfg['tokenizer_class']}")
|
FALLBACK = {
|
||||||
|
"fast": "PreTrainedTokenizerFast",
|
||||||
|
"sentencepiece": "LlamaTokenizer",
|
||||||
|
"bpe": "GPT2TokenizerFast",
|
||||||
|
}
|
||||||
|
|
||||||
|
if orig_cls and orig_cls in VALID_CLASSES and orig_cls not in BAD_CLASSES:
|
||||||
|
print(f"[fix_tokenizer] tokenizer_class '{orig_cls}' is valid, skip override")
|
||||||
|
else:
|
||||||
|
fallback = FALLBACK.get(typ, "PreTrainedTokenizerFast")
|
||||||
|
if orig_cls:
|
||||||
|
print(f"[fix] override bad tokenizer_class: {orig_cls} → {fallback}")
|
||||||
|
else:
|
||||||
|
print(f"[fix] tokenizer_class missing, set to: {fallback}")
|
||||||
|
cfg["tokenizer_class"] = fallback
|
||||||
|
|
||||||
with open(cfg_path, "w") as f:
|
with open(cfg_path, "w") as f:
|
||||||
json.dump(cfg, f)
|
json.dump(cfg, f)
|
||||||
|
|||||||
Reference in New Issue
Block a user