添加 NV A100 Patched 镜像合并/fix_tokenizer.py
This commit is contained in:
52
NV A100 Patched 镜像合并/fix_tokenizer.py
Normal file
52
NV A100 Patched 镜像合并/fix_tokenizer.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import os
|
||||
import shutil
|
||||
import json
|
||||
from detect_tokenizer import detect
|
||||
|
||||
# Directories are configurable via the environment: MODEL_DIR holds the
# source model, FIX_TOKENIZER_DIR receives the repaired tokenizer files.
MODEL_DIR = os.getenv("MODEL_DIR", "/model")
OUT_DIR = os.getenv("FIX_TOKENIZER_DIR", "/tmp/fixed_tokenizer")

# Ensure the output directory exists before any files are copied into it.
os.makedirs(OUT_DIR, exist_ok=True)
def copy_if_exists(name, src_dir=None, dst_dir=None):
    """Copy the file *name* from ``src_dir`` into ``dst_dir`` if it exists.

    Args:
        name: Bare file name to copy (e.g. ``"tokenizer.json"``).
        src_dir: Source directory; defaults to the module-level MODEL_DIR.
        dst_dir: Destination directory; defaults to the module-level OUT_DIR.

    Missing files are skipped silently: not every model ships every
    tokenizer artifact, so absence is expected rather than an error.
    """
    src = os.path.join(MODEL_DIR if src_dir is None else src_dir, name)
    if os.path.exists(src):
        # shutil.copy copies content + permissions (not full metadata),
        # which is sufficient for tokenizer assets.
        shutil.copy(src, OUT_DIR if dst_dir is None else dst_dir)
# Stage whichever of the known tokenizer artifacts this model provides;
# copy_if_exists skips any that are absent.
for f in (
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "vocab.json",
    "merges.txt",
    "tokenizer.model",
):
    copy_if_exists(f)
# Identify the tokenizer family present in the model directory and the
# class name its original config declared.
typ, orig_cls = detect(MODEL_DIR)

cfg_path = os.path.join(OUT_DIR, "tokenizer_config.json")

# Start from the staged tokenizer_config.json when one exists,
# otherwise build a fresh config from scratch.
cfg = {}
if os.path.exists(cfg_path):
    with open(cfg_path) as f:
        cfg = json.load(f)

# Map the detected tokenizer type to a loadable transformers class;
# anything unrecognized falls back to the generic fast tokenizer.
cfg["tokenizer_class"] = {
    "fast": "PreTrainedTokenizerFast",
    "sentencepiece": "LlamaTokenizer",
    "bpe": "GPT2TokenizerFast",
}.get(typ, "PreTrainedTokenizerFast")

# NOTE(review): the class above is rewritten unconditionally; this check
# only emits a log line when a known-bad original class was replaced.
bad_classes = ["TokenizersBackend", "TiktokenTokenizer"]
if orig_cls in bad_classes:
    print(f"[fix] override bad tokenizer_class: {orig_cls} → {cfg['tokenizer_class']}")

# Persist the patched config next to the copied tokenizer files.
with open(cfg_path, "w") as f:
    json.dump(cfg, f)

print(f"[fix_tokenizer] done → {OUT_DIR}")
Reference in New Issue
Block a user