fix tokenizer

This commit is contained in:
4paradigm
2026-06-29 17:23:40 +08:00
parent ef6173824e
commit 42420f61ea
4 changed files with 57 additions and 106 deletions

View File

@@ -1,69 +1,71 @@
#!/usr/bin/env python3
"""
检测 tokenizer_config.json 中的 tokenizer_class 是否在 transformers 中存在。
若不存在(如 TokenizersBackend则将 tokenizer 文件复制到 /tmp/fixed_tokenizer/
并修复 tokenizer_class最后将修复目录路径输出到 stdout。
若无需修复,输出为空。
"""
import os
import shutil
import sys
import json
from detect_tokenizer import detect
import shutil
MODEL_DIR = os.environ.get("MODEL_DIR", "/model")
OUT_DIR = os.environ.get("FIX_TOKENIZER_DIR", "/tmp/fixed_tokenizer")
MODEL_DIR = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("MODEL_DIR", "/model")
OUT_DIR = "/tmp/fixed_tokenizer"
os.makedirs(OUT_DIR, exist_ok=True)
def copy_if_exists(name):
src = os.path.join(MODEL_DIR, name)
if os.path.exists(src):
shutil.copy(src, OUT_DIR)
def main():
cfg_path = os.path.join(MODEL_DIR, "tokenizer_config.json")
if not os.path.exists(cfg_path):
return
# 复制所有可能相关文件
for f in [
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"vocab.json",
"merges.txt",
"tokenizer.model",
]:
copy_if_exists(f)
typ, orig_cls = detect(MODEL_DIR)
cfg_path = os.path.join(OUT_DIR, "tokenizer_config.json")
if os.path.exists(cfg_path):
with open(cfg_path) as f:
cfg = json.load(f)
else:
cfg = {}
# ===== 自动修复策略 =====
if typ == "fast":
cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
tokenizer_class = cfg.get("tokenizer_class", "")
if not tokenizer_class:
return
elif typ == "sentencepiece":
cfg["tokenizer_class"] = "LlamaTokenizer"
# 用 transformers 自身判断该类是否可用,不硬编码类名
import transformers
if getattr(transformers, tokenizer_class, None) is not None:
return # 类存在,无需修复
elif typ == "bpe":
cfg["tokenizer_class"] = "GPT2TokenizerFast"
# tokenizer_class 在 transformers 中不存在,根据实际文件推断正确的类
files = os.listdir(MODEL_DIR)
if "tokenizer.json" in files:
fixed_class = "PreTrainedTokenizerFast"
elif "tokenizer.model" in files:
fixed_class = "LlamaTokenizer"
elif "vocab.json" in files and "merges.txt" in files:
fixed_class = "GPT2TokenizerFast"
else:
fixed_class = "PreTrainedTokenizerFast"
else:
cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
print(
f"[fix_tokenizer] tokenizer_class '{tokenizer_class}' not found in transformers, "
f"replacing with '{fixed_class}'",
file=sys.stderr,
)
# 特殊 case 修复
bad_classes = [
"TokenizersBackend",
"TiktokenTokenizer",
]
os.makedirs(OUT_DIR, exist_ok=True)
for fname in [
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"vocab.json",
"merges.txt",
"tokenizer.model",
]:
src = os.path.join(MODEL_DIR, fname)
if os.path.exists(src):
shutil.copy(src, OUT_DIR)
if orig_cls in bad_classes:
print(f"[fix] override bad tokenizer_class: {orig_cls}{cfg['tokenizer_class']}")
cfg["tokenizer_class"] = fixed_class
with open(os.path.join(OUT_DIR, "tokenizer_config.json"), "w") as f:
json.dump(cfg, f, indent=2)
# 修复 extra_special_tokens: list → dict 格式
if "extra_special_tokens" in cfg and isinstance(cfg["extra_special_tokens"], list):
orig_list = cfg["extra_special_tokens"]
cfg["extra_special_tokens"] = {token: token for token in orig_list}
print(f"[fix] converted extra_special_tokens from list ({len(orig_list)} items) to dict format")
print(OUT_DIR) # 输出修复目录,供 entrypoint.sh 捕获
# 写回
with open(cfg_path, "w") as f:
json.dump(cfg, f)
print(f"[fix_tokenizer] done → {OUT_DIR}")
main()