Files
vLLM-Kunlunxin_p-800-tokeni…/fix_tokenizer.py
2026-06-29 17:23:40 +08:00

72 lines
2.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
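"""
Check whether the tokenizer_class declared in tokenizer_config.json exists in
transformers. If it does not (e.g. TokenizersBackend), copy the tokenizer files
to /tmp/fixed_tokenizer/, replace tokenizer_class with a class inferred from
the tokenizer files present, and finally print the path of the fixed directory
to stdout.
If no fix is needed, nothing is printed to stdout.
"""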
import os
import sys
import json
import shutil
# Model directory to inspect: the first command-line argument wins, then the
# MODEL_DIR environment variable, then the "/model" default.
if len(sys.argv) > 1:
    MODEL_DIR = sys.argv[1]
else:
    MODEL_DIR = os.environ.get("MODEL_DIR", "/model")

# Directory that receives the patched copy of the tokenizer files.
OUT_DIR = "/tmp/fixed_tokenizer"
# Tokenizer-related files copied into the fixed directory when present.
# The chat_template.* and added_tokens.json side files are included because a
# saved tokenizer may keep its chat template / extra tokens there; leaving them
# out would make the fixed directory silently lose that information.
_TOKENIZER_FILES = (
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "added_tokens.json",
    "vocab.json",
    "merges.txt",
    "tokenizer.model",
    "chat_template.jinja",
    "chat_template.json",
)


def _infer_tokenizer_class(filenames):
    """Choose a transformers tokenizer class able to load the given files.

    Args:
        filenames: Names of the entries present in the model directory.

    Returns:
        The class name to write into ``tokenizer_config.json``.
    """
    names = set(filenames)
    if "tokenizer.json" in names:
        # A serialized `tokenizers` file loads with the generic fast class.
        return "PreTrainedTokenizerFast"
    if "tokenizer.model" in names:
        return "LlamaTokenizer"
    if {"vocab.json", "merges.txt"} <= names:
        return "GPT2TokenizerFast"
    # Nothing recognizable: fall back to the generic fast tokenizer.
    return "PreTrainedTokenizerFast"


def _write_fixed_tokenizer(model_dir, out_dir, cfg, fixed_class):
    """Copy tokenizer files into an emptied *out_dir* and write the patched config.

    Args:
        model_dir: Directory holding the original tokenizer files.
        out_dir: Destination directory; any existing content is removed.
        cfg: Parsed ``tokenizer_config.json`` (not modified).
        fixed_class: Replacement value for ``tokenizer_class``.
    """
    # Start from an empty directory so files left by an earlier run (possibly
    # for a different model) cannot leak into this tokenizer.
    shutil.rmtree(out_dir, ignore_errors=True)
    os.makedirs(out_dir, exist_ok=True)
    for fname in _TOKENIZER_FILES:
        src = os.path.join(model_dir, fname)
        # isfile (follows symlinks) skips directories that shutil.copy cannot copy.
        if os.path.isfile(src):
            shutil.copy(src, out_dir)
    patched = dict(cfg, tokenizer_class=fixed_class)
    # Overwrites the tokenizer_config.json copied above with the patched one.
    config_out = os.path.join(out_dir, "tokenizer_config.json")
    with open(config_out, "w", encoding="utf-8") as f:
        json.dump(patched, f, indent=2)


def main():
    """Patch an unsupported ``tokenizer_class`` and report the fixed directory.

    Reads ``tokenizer_config.json`` from the module-level ``MODEL_DIR``. When
    its ``tokenizer_class`` is a non-empty string that the installed
    ``transformers`` package does not expose (e.g. ``TokenizersBackend``), the
    tokenizer files are copied into a freshly emptied ``OUT_DIR``, the config
    is rewritten with a class inferred from the files present, and ``OUT_DIR``
    is printed to stdout. In every other case (no config file, missing or
    non-string class, class available) nothing is written to stdout.

    Raises:
        json.JSONDecodeError: If ``tokenizer_config.json`` is not valid JSON.
        ImportError: If ``transformers`` cannot be imported.
    """
    cfg_path = os.path.join(MODEL_DIR, "tokenizer_config.json")
    if not os.path.exists(cfg_path):
        return
    # JSON files are UTF-8; do not depend on the container's locale.
    with open(cfg_path, encoding="utf-8") as f:
        cfg = json.load(f)
    if not isinstance(cfg, dict):
        return
    tokenizer_class = cfg.get("tokenizer_class", "")
    # getattr() below requires a string; treat anything else as "nothing to fix".
    if not isinstance(tokenizer_class, str) or not tokenizer_class:
        return

    # Ask transformers itself whether the class is available instead of
    # hard-coding a list of known class names.
    import transformers

    if getattr(transformers, tokenizer_class, None) is not None:
        return  # class exists, nothing to fix

    # The class is unknown to transformers: infer a usable one from the files.
    fixed_class = _infer_tokenizer_class(os.listdir(MODEL_DIR))
    print(
        f"[fix_tokenizer] tokenizer_class '{tokenizer_class}' not found in transformers, "
        f"replacing with '{fixed_class}'",
        file=sys.stderr,
    )
    _write_fixed_tokenizer(MODEL_DIR, OUT_DIR, cfg, fixed_class)
    print(OUT_DIR)  # the fixed directory, captured by entrypoint.sh
# Run only when executed as a script (e.g. from entrypoint.sh), not on import.
if __name__ == "__main__":
    main()