Files
vLLM-Kunlunxin_p-800-tokeni…/fix_tokenizer.py
2026-06-30 13:57:38 +08:00

96 lines
3.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
检测并修复 tokenizer_config.json 中的两类问题：
1. tokenizer_class 在 transformers 中不存在（如 TokenizersBackend）；
2. extra_special_tokens 为 list 格式（transformers 要求 dict）。
若存在问题，将 tokenizer 文件复制到 /tmp/fixed_tokenizer/ 并修复，
最后将修复目录路径输出到 stdout。若无需修复，输出为空。
"""
import os
import sys
import json
import shutil
# Source model directory: first CLI argument if given, otherwise the
# MODEL_DIR environment variable, otherwise "/model".
if len(sys.argv) >= 2:
    MODEL_DIR = sys.argv[1]
else:
    MODEL_DIR = os.environ.get("MODEL_DIR", "/model")
# Destination for the repaired tokenizer copy; its path is printed to stdout
# when a fix is applied.
OUT_DIR = "/tmp/fixed_tokenizer"
# Tokenizer-related files copied next to the patched tokenizer_config.json so
# the output directory can be loaded on its own. chat_template.jinja/.json are
# included because newer transformers releases keep the chat template outside
# tokenizer_config.json; dropping them would silently lose the template.
_TOKENIZER_FILES = (
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "added_tokens.json",
    "vocab.json",
    "vocab.txt",
    "merges.txt",
    "tokenizer.model",
    "chat_template.jinja",
    "chat_template.json",
)


def _tokenizer_class_missing(tokenizer_class):
    """Return True if *tokenizer_class* is not an attribute of ``transformers``.

    Returns False (and notes it on stderr) when transformers cannot be
    imported: the check is impossible then, but other fixes should still run.
    """
    try:
        import transformers
    except ImportError:
        print(
            "[fix_tokenizer] transformers not importable; skipping tokenizer_class check",
            file=sys.stderr,
        )
        return False
    return getattr(transformers, tokenizer_class, None) is None


def _fallback_tokenizer_class(model_dir):
    """Choose a replacement tokenizer class based on the files in *model_dir*."""
    files = set(os.listdir(model_dir))
    if "tokenizer.json" in files:
        return "PreTrainedTokenizerFast"
    if "tokenizer.model" in files:
        return "LlamaTokenizer"
    if {"vocab.json", "merges.txt"} <= files:
        return "GPT2TokenizerFast"
    return "PreTrainedTokenizerFast"


def _extra_special_tokens_to_dict(tokens):
    """Convert a list of extra special tokens to a ``{token: token}`` dict.

    Entries serialized as dicts with a ``"content"`` field (e.g. AddedToken
    objects) are keyed by that content string; a dict itself cannot be a key.
    """
    mapping = {}
    for token in tokens:
        if isinstance(token, dict) and "content" in token:
            token = token["content"]
        mapping[token] = token
    return mapping


def _copy_tokenizer_files(model_dir, out_dir):
    """Mirror the known tokenizer files from *model_dir* into *out_dir*.

    A known file that is absent from *model_dir* is deleted from *out_dir*, so
    leftovers from an earlier run cannot be mixed into the repaired tokenizer.
    """
    os.makedirs(out_dir, exist_ok=True)
    for fname in _TOKENIZER_FILES:
        src = os.path.join(model_dir, fname)
        dst = os.path.join(out_dir, fname)
        if os.path.exists(src):
            shutil.copy(src, dst)
        elif os.path.exists(dst):
            os.remove(dst)


def main(model_dir=None, out_dir=None):
    """Detect and repair two tokenizer_config.json problems.

    1. ``tokenizer_class`` names a class that transformers does not export.
    2. ``extra_special_tokens`` is a list instead of a dict.

    If either is found, the tokenizer files are copied to *out_dir*, the
    config there is patched, and *out_dir* is printed to stdout. If nothing
    needs fixing, stdout stays empty.

    Args:
        model_dir: directory with the original tokenizer files; defaults to
            the module-level ``MODEL_DIR``.
        out_dir: directory that receives the repaired copy; defaults to the
            module-level ``OUT_DIR``.

    Returns:
        The repaired directory path, or None when there is no
        tokenizer_config.json or nothing needs fixing.
    """
    model_dir = MODEL_DIR if model_dir is None else model_dir
    out_dir = OUT_DIR if out_dir is None else out_dir

    cfg_path = os.path.join(model_dir, "tokenizer_config.json")
    if not os.path.exists(cfg_path):
        return None
    # JSON is UTF-8; don't rely on the container's locale encoding.
    with open(cfg_path, encoding="utf-8") as f:
        cfg = json.load(f)

    # Check 1: does tokenizer_class name a class exported by transformers?
    tokenizer_class = cfg.get("tokenizer_class", "")
    bad_tokenizer_class = (
        isinstance(tokenizer_class, str)
        and bool(tokenizer_class)
        and _tokenizer_class_missing(tokenizer_class)
    )

    # Check 2: is extra_special_tokens a list rather than the required dict?
    bad_extra_special_tokens = isinstance(cfg.get("extra_special_tokens"), list)

    if not (bad_tokenizer_class or bad_extra_special_tokens):
        return None  # nothing to fix; stdout stays empty

    _copy_tokenizer_files(model_dir, out_dir)

    # Fix 1: replace the unknown class with one matching the available files.
    if bad_tokenizer_class:
        fixed_class = _fallback_tokenizer_class(model_dir)
        cfg["tokenizer_class"] = fixed_class
        print(
            f"[fix_tokenizer] tokenizer_class: '{tokenizer_class}' → '{fixed_class}'",
            file=sys.stderr,
        )

    # Fix 2: convert the extra_special_tokens list into a dict.
    if bad_extra_special_tokens:
        orig_list = cfg["extra_special_tokens"]
        cfg["extra_special_tokens"] = _extra_special_tokens_to_dict(orig_list)
        print(
            f"[fix_tokenizer] extra_special_tokens: list({len(orig_list)}) → dict",
            file=sys.stderr,
        )

    with open(os.path.join(out_dir, "tokenizer_config.json"), "w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)
    # The repaired directory path on stdout is captured by entrypoint.sh.
    print(out_dir)
    return out_dir
# Run only when executed as a script, so importing this module has no side effects.
if __name__ == "__main__":
    main()