first commit

This commit is contained in:
2026-05-28 10:56:17 +08:00
commit b0b0248cee
5 changed files with 308 additions and 0 deletions

63
fix_tokenizer.py Normal file
View File

@@ -0,0 +1,63 @@
import os
import shutil
import json
from detect_tokenizer import detect
MODEL_DIR = os.environ.get("MODEL_DIR", "/model")
OUT_DIR = os.environ.get("FIX_TOKENIZER_DIR", "/tmp/fixed_tokenizer")
os.makedirs(OUT_DIR, exist_ok=True)
def copy_if_exists(name):
src = os.path.join(MODEL_DIR, name)
if os.path.exists(src):
shutil.copy(src, OUT_DIR)
# 复制所有可能相关文件
for f in [
"tokenizer.json",
"tokenizer_config.json",
"special_tokens_map.json",
"vocab.json",
"merges.txt",
"tokenizer.model",
]:
copy_if_exists(f)
typ, orig_cls = detect(MODEL_DIR)
cfg_path = os.path.join(OUT_DIR, "tokenizer_config.json")
if os.path.exists(cfg_path):
with open(cfg_path) as f:
cfg = json.load(f)
else:
cfg = {}
# ===== 自动修复策略 =====
if typ == "fast":
cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
elif typ == "sentencepiece":
cfg["tokenizer_class"] = "LlamaTokenizer"
elif typ == "bpe":
cfg["tokenizer_class"] = "GPT2TokenizerFast"
else:
cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
# 特殊 case 修复
bad_classes = [
"TokenizersBackend",
"TiktokenTokenizer",
]
if orig_cls in bad_classes:
print(f"[fix] override bad tokenizer_class: {orig_cls}{cfg['tokenizer_class']}")
# 写回
with open(cfg_path, "w") as f:
json.dump(cfg, f)
print(f"[fix_tokenizer] done → {OUT_DIR}")