fix tokenizer

This commit is contained in:
4paradigm
2026-06-30 13:57:38 +08:00
parent 42420f61ea
commit 55f8585e15
4 changed files with 80 additions and 62 deletions

View File

@@ -1,9 +1,11 @@
#!/usr/bin/env python3
"""
检测 tokenizer_config.json 中的 tokenizer_class 是否在 transformers 中存在。
若不存在(如 TokenizersBackend),则将 tokenizer 文件复制到 /tmp/fixed_tokenizer/
并修复 tokenizer_class,最后将修复目录路径输出到 stdout。
若无需修复,输出为空。
检测并修复 tokenizer_config.json 中的两类问题:
1. tokenizer_class 在 transformers 中不存在(如 TokenizersBackend)
2. extra_special_tokens 为 list 格式(transformers 要求 dict)
若存在问题,将 tokenizer 文件复制到 /tmp/fixed_tokenizer/ 并修复,
最后将修复目录路径输出到 stdout。若无需修复,输出为空。
"""
import os
import sys
@@ -22,32 +24,29 @@ def main():
with open(cfg_path) as f:
cfg = json.load(f)
fixes = []
# --- 检测 1:tokenizer_class 是否在 transformers 中存在 ---
tokenizer_class = cfg.get("tokenizer_class", "")
if not tokenizer_class:
return
bad_tokenizer_class = False
if tokenizer_class:
import transformers
if getattr(transformers, tokenizer_class, None) is None:
bad_tokenizer_class = True
fixes.append(f"tokenizer_class '{tokenizer_class}' not found in transformers")
# 用 transformers 自身判断该类是否可用,不硬编码类名
import transformers
if getattr(transformers, tokenizer_class, None) is not None:
return # 类存在,无需修复
# tokenizer_class 在 transformers 中不存在,根据实际文件推断正确的类
files = os.listdir(MODEL_DIR)
if "tokenizer.json" in files:
fixed_class = "PreTrainedTokenizerFast"
elif "tokenizer.model" in files:
fixed_class = "LlamaTokenizer"
elif "vocab.json" in files and "merges.txt" in files:
fixed_class = "GPT2TokenizerFast"
else:
fixed_class = "PreTrainedTokenizerFast"
print(
f"[fix_tokenizer] tokenizer_class '{tokenizer_class}' not found in transformers, "
f"replacing with '{fixed_class}'",
file=sys.stderr,
# --- 检测 2:extra_special_tokens 是否为 list 格式 ---
bad_extra_special_tokens = (
"extra_special_tokens" in cfg
and isinstance(cfg["extra_special_tokens"], list)
)
if bad_extra_special_tokens:
fixes.append("extra_special_tokens is a list, expected dict")
if not fixes:
return # 无需修复
# 复制 tokenizer 文件到临时目录
os.makedirs(OUT_DIR, exist_ok=True)
for fname in [
"tokenizer.json",
@@ -61,7 +60,32 @@ def main():
if os.path.exists(src):
shutil.copy(src, OUT_DIR)
cfg["tokenizer_class"] = fixed_class
# --- 修复 1:替换 tokenizer_class ---
if bad_tokenizer_class:
files = os.listdir(MODEL_DIR)
if "tokenizer.json" in files:
fixed_class = "PreTrainedTokenizerFast"
elif "tokenizer.model" in files:
fixed_class = "LlamaTokenizer"
elif "vocab.json" in files and "merges.txt" in files:
fixed_class = "GPT2TokenizerFast"
else:
fixed_class = "PreTrainedTokenizerFast"
cfg["tokenizer_class"] = fixed_class
print(
f"[fix_tokenizer] tokenizer_class: '{tokenizer_class}''{fixed_class}'",
file=sys.stderr,
)
# --- 修复 2:extra_special_tokens list → dict ---
if bad_extra_special_tokens:
orig_list = cfg["extra_special_tokens"]
cfg["extra_special_tokens"] = {token: token for token in orig_list}
print(
f"[fix_tokenizer] extra_special_tokens: list({len(orig_list)}) → dict",
file=sys.stderr,
)
with open(os.path.join(OUT_DIR, "tokenizer_config.json"), "w") as f:
json.dump(cfg, f, indent=2)