72 lines
2.1 KiB
Python
72 lines
2.1 KiB
Python
#!/usr/bin/env python3
|
||
"""
Check whether the tokenizer_class named in tokenizer_config.json exists in transformers.
If it does not (e.g. TokenizersBackend), copy the tokenizer files to /tmp/fixed_tokenizer/,
fix tokenizer_class there, and finally print the path of the fixed directory to stdout.
If no fix is needed, the output is empty.
"""
|
||
import os
|
||
import sys
|
||
import json
|
||
import shutil
|
||
|
||
# Model directory: the first command-line argument wins, then the MODEL_DIR
# environment variable, then the "/model" fallback.
if len(sys.argv) > 1:
    MODEL_DIR = sys.argv[1]
else:
    MODEL_DIR = os.environ.get("MODEL_DIR", "/model")

# Directory that receives the repaired copy of the tokenizer files.
OUT_DIR = "/tmp/fixed_tokenizer"
|
||
|
||
|
||
def _infer_tokenizer_class(files):
    """Pick a transformers tokenizer class name from the file names present in a model directory."""
    if "tokenizer.json" in files:
        return "PreTrainedTokenizerFast"
    if "tokenizer.model" in files:
        return "LlamaTokenizer"
    if "vocab.json" in files and "merges.txt" in files:
        return "GPT2TokenizerFast"
    # Nothing recognisable: fall back to the generic fast tokenizer.
    return "PreTrainedTokenizerFast"


def main(model_dir=None, out_dir=None):
    """Repair an unknown ``tokenizer_class`` in a model's tokenizer config.

    Reads ``tokenizer_config.json`` from the model directory. When its
    ``tokenizer_class`` is not an attribute of the installed ``transformers``
    package, the known tokenizer files are copied to the output directory, the
    copied config is rewritten with a class inferred from the files present,
    and the output directory is printed to stdout. When nothing needs
    repairing, nothing is printed.

    Args:
        model_dir: Directory holding the tokenizer files. Defaults to the
            module-level ``MODEL_DIR``.
        out_dir: Directory that receives the repaired copy. Defaults to the
            module-level ``OUT_DIR``.

    Returns:
        The output directory if a repaired copy was written, otherwise ``None``.

    Raises:
        ImportError: If ``transformers`` cannot be imported when the class
            lookup is needed.
        json.JSONDecodeError: If ``tokenizer_config.json`` is not valid JSON.
    """
    # Resolve the defaults at call time so the globals are only needed when
    # the caller does not pass explicit directories.
    if model_dir is None:
        model_dir = MODEL_DIR
    if out_dir is None:
        out_dir = OUT_DIR

    cfg_path = os.path.join(model_dir, "tokenizer_config.json")
    if not os.path.exists(cfg_path):
        return None

    # Read as UTF-8 explicitly: the config may contain non-ASCII tokens and
    # the locale's default encoding must not decide whether it parses.
    with open(cfg_path, encoding="utf-8") as f:
        cfg = json.load(f)

    # A non-dict document or a non-string class name cannot be looked up
    # (getattr requires a str), so treat both as "nothing to repair".
    tokenizer_class = cfg.get("tokenizer_class") if isinstance(cfg, dict) else None
    if not isinstance(tokenizer_class, str) or not tokenizer_class:
        return None

    # Let transformers itself decide whether the class is available instead
    # of hard-coding class names.
    import transformers

    if getattr(transformers, tokenizer_class, None) is not None:
        return None  # the class exists, no repair needed

    # tokenizer_class is unknown to transformers: infer a usable class from
    # the files that are actually present.
    fixed_class = _infer_tokenizer_class(os.listdir(model_dir))

    print(
        f"[fix_tokenizer] tokenizer_class '{tokenizer_class}' not found in transformers, "
        f"replacing with '{fixed_class}'",
        file=sys.stderr,
    )

    os.makedirs(out_dir, exist_ok=True)
    for fname in (
        "tokenizer.json",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "vocab.json",
        "merges.txt",
        "tokenizer.model",
    ):
        src = os.path.join(model_dir, fname)
        dst = os.path.join(out_dir, fname)
        if os.path.exists(src):
            shutil.copy(src, out_dir)
        elif os.path.exists(dst):
            # The output directory is reused across runs; drop a file left by
            # an earlier run so it is not mixed into this model's tokenizer.
            os.remove(dst)

    cfg["tokenizer_class"] = fixed_class
    with open(os.path.join(out_dir, "tokenizer_config.json"), "w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2)

    print(out_dir)  # emit the repaired directory so entrypoint.sh can capture it
    return out_dir
|
||
|
||
|
||
# Script entry point. There is no __main__ guard, so this also runs when the
# module is imported.
main()
|