2026-06-29 17:23:40 +08:00
|
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
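"""
检测并修复 tokenizer_config.json 中的两类问题:

1. tokenizer_class 在 transformers 中不存在(如 TokenizersBackend)
2. extra_special_tokens 为 list 格式(transformers 要求 dict)

若存在问题,将 tokenizer 文件复制到 /tmp/fixed_tokenizer/ 并修复,
最后将修复目录路径输出到 stdout。若无需修复,输出为空。

Detect and fix two kinds of problems in tokenizer_config.json:

1. tokenizer_class names a class that transformers does not provide
   (e.g. TokenizersBackend).
2. extra_special_tokens is stored as a list, while transformers expects a dict.

If either problem is present, the tokenizer files are copied to
/tmp/fixed_tokenizer/ and fixed there, and that directory path is printed to
stdout. If nothing needs fixing, nothing is printed to stdout.
"""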
|
2026-06-29 17:04:41 +08:00
|
|
|
|
import os
|
2026-06-29 17:23:40 +08:00
|
|
|
|
import sys
|
2026-06-29 17:04:41 +08:00
|
|
|
|
import json
|
2026-06-29 17:23:40 +08:00
|
|
|
|
import shutil
|
2026-06-29 17:04:41 +08:00
|
|
|
|
|
2026-06-29 17:23:40 +08:00
|
|
|
|
# Directory the tokenizer is read from: the first command-line argument when
# one is given, otherwise the MODEL_DIR environment variable, else "/model".
if len(sys.argv) >= 2:
    MODEL_DIR = sys.argv[1]
else:
    MODEL_DIR = os.getenv("MODEL_DIR", "/model")

# Destination for the fixed tokenizer files; main() prints it when a fix is written.
OUT_DIR = "/tmp/fixed_tokenizer"
|
2026-06-29 17:04:41 +08:00
|
|
|
|
|
|
|
|
|
|
|
2026-06-29 17:23:40 +08:00
|
|
|
|
# Tokenizer files copied into the output directory when a fix is needed. Besides
# the vocabulary / merges / model files, this includes side files that
# transformers can save next to tokenizer_config.json (added tokens, chat
# templates); leaving them out would silently drop them from the fixed copy.
_TOKENIZER_FILES = (
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "added_tokens.json",
    "vocab.json",
    "vocab.txt",
    "merges.txt",
    "tokenizer.model",
    "spiece.model",
    "chat_template.jinja",
    "chat_template.json",
)


def _tokenizer_class_exists(name):
    """Report whether transformers exposes a tokenizer class called *name*.

    Like AutoTokenizer, the ``<name>Fast`` variant is accepted as well.

    Returns:
        True if a matching class is found; False if none is found or *name*
        is not a string; None if transformers itself cannot be imported, in
        which case the check could not be performed.
    """
    if not isinstance(name, str):
        return False
    try:
        import transformers
    except ImportError:
        return None
    candidates = [name] if name.endswith("Fast") else [name, name + "Fast"]
    for candidate in candidates:
        try:
            if getattr(transformers, candidate, None) is not None:
                return True
        except (ImportError, RuntimeError):
            # transformers resolves attributes lazily and can raise these when
            # the backing submodule fails to import; treat it as unavailable.
            continue
    return False


def _fallback_tokenizer_class(model_dir):
    """Choose a generic transformers tokenizer class from the files in *model_dir*."""
    files = set(os.listdir(model_dir))
    if "tokenizer.json" in files:
        return "PreTrainedTokenizerFast"
    if "tokenizer.model" in files:
        return "LlamaTokenizer"
    if {"vocab.json", "merges.txt"} <= files:
        return "GPT2TokenizerFast"
    return "PreTrainedTokenizerFast"


def _extra_special_tokens_to_dict(tokens):
    """Convert a list of extra special tokens into a ``{token: token}`` dict.

    Entries given as dicts (AddedToken-style, carrying a ``"content"`` key)
    are reduced to their content string, because a dict cannot be a key.
    """
    result = {}
    for token in tokens:
        content = token["content"] if isinstance(token, dict) else token
        result[content] = content
    return result


def _copy_tokenizer_files(model_dir, out_dir):
    """Copy each file of _TOKENIZER_FILES present in *model_dir* into *out_dir*.

    A copy of one of those names left in *out_dir* by an earlier run, but absent
    from *model_dir*, is removed so the output reflects only the current model.
    """
    os.makedirs(out_dir, exist_ok=True)
    for fname in _TOKENIZER_FILES:
        src = os.path.join(model_dir, fname)
        dst = os.path.join(out_dir, fname)
        if os.path.isfile(src):
            shutil.copy(src, dst)
        elif os.path.isfile(dst):
            os.remove(dst)


def main(model_dir=None, out_dir=None):
    """Write a fixed copy of the tokenizer if its tokenizer_config.json needs one.

    Two problems are detected: a tokenizer_class that transformers does not
    provide (e.g. TokenizersBackend), and extra_special_tokens stored as a list
    where transformers expects a dict. If either is found, the tokenizer files
    are copied to *out_dir*, the config is rewritten there, and *out_dir* is
    printed to stdout. Otherwise nothing is printed to stdout. Progress and
    diagnostic messages go to stderr.

    Args:
        model_dir: directory containing tokenizer_config.json; defaults to the
            module-level MODEL_DIR.
        out_dir: directory that receives the fixed files; defaults to the
            module-level OUT_DIR.
    """
    if model_dir is None:
        model_dir = MODEL_DIR
    if out_dir is None:
        out_dir = OUT_DIR

    cfg_path = os.path.join(model_dir, "tokenizer_config.json")
    if not os.path.exists(cfg_path):
        return

    with open(cfg_path, encoding="utf-8") as f:
        cfg = json.load(f)

    fixes = []

    # --- Check 1: tokenizer_class must exist in transformers ---
    tokenizer_class = cfg.get("tokenizer_class", "")
    bad_tokenizer_class = False
    if tokenizer_class:
        found = _tokenizer_class_exists(tokenizer_class)
        if found is None:
            print(
                "[fix_tokenizer] transformers is not importable; "
                "skipping tokenizer_class check",
                file=sys.stderr,
            )
        elif not found:
            bad_tokenizer_class = True
            fixes.append(f"tokenizer_class '{tokenizer_class}' not found in transformers")

    # --- Check 2: extra_special_tokens must not be a list ---
    bad_extra_special_tokens = isinstance(cfg.get("extra_special_tokens"), list)
    if bad_extra_special_tokens:
        fixes.append("extra_special_tokens is a list, expected dict")

    if not fixes:
        return  # nothing to fix

    _copy_tokenizer_files(model_dir, out_dir)

    # --- Fix 1: replace tokenizer_class with a generic class ---
    if bad_tokenizer_class:
        fixed_class = _fallback_tokenizer_class(model_dir)
        cfg["tokenizer_class"] = fixed_class
        print(
            f"[fix_tokenizer] tokenizer_class: '{tokenizer_class}' → '{fixed_class}'",
            file=sys.stderr,
        )

    # --- Fix 2: extra_special_tokens list → dict ---
    if bad_extra_special_tokens:
        orig_list = cfg["extra_special_tokens"]
        cfg["extra_special_tokens"] = _extra_special_tokens_to_dict(orig_list)
        print(
            f"[fix_tokenizer] extra_special_tokens: list({len(orig_list)}) → dict",
            file=sys.stderr,
        )

    with open(os.path.join(out_dir, "tokenizer_config.json"), "w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2, ensure_ascii=False)

    # Print the fixed directory so entrypoint.sh can capture it from stdout.
    print(out_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|