fix tokenizer
This commit is contained in:
@@ -2,7 +2,6 @@ FROM harbor-contest.4pd.io/sunjichen/xc-llm-kunlun:latest
|
|||||||
|
|
||||||
COPY entrypoint.sh /opt/entrypoint.sh
|
COPY entrypoint.sh /opt/entrypoint.sh
|
||||||
COPY fix_tokenizer.py /opt/fix_tokenizer.py
|
COPY fix_tokenizer.py /opt/fix_tokenizer.py
|
||||||
COPY detect_tokenizer.py /opt/detect_tokenizer.py
|
|
||||||
|
|
||||||
RUN chmod +x /opt/entrypoint.sh
|
RUN chmod +x /opt/entrypoint.sh
|
||||||
|
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
import os
|
|
||||||
import json
|
|
||||||
|
|
||||||
def detect(model_dir):
|
|
||||||
cfg_path = os.path.join(model_dir, "tokenizer_config.json")
|
|
||||||
|
|
||||||
if os.path.exists(cfg_path):
|
|
||||||
with open(cfg_path) as f:
|
|
||||||
cfg = json.load(f)
|
|
||||||
cls = cfg.get("tokenizer_class", "")
|
|
||||||
else:
|
|
||||||
cls = ""
|
|
||||||
|
|
||||||
files = os.listdir(model_dir)
|
|
||||||
|
|
||||||
if "tokenizer.json" in files:
|
|
||||||
return "fast", cls
|
|
||||||
|
|
||||||
if "tokenizer.model" in files:
|
|
||||||
return "sentencepiece", cls
|
|
||||||
|
|
||||||
if "vocab.json" in files and "merges.txt" in files:
|
|
||||||
return "bpe", cls
|
|
||||||
|
|
||||||
return "unknown", cls
|
|
||||||
@@ -4,36 +4,11 @@ set -e
|
|||||||
MODEL_DIR=${1:-/model}
|
MODEL_DIR=${1:-/model}
|
||||||
shift || true
|
shift || true
|
||||||
|
|
||||||
FIX_TOKENIZER_DIR=/tmp/fixed_tokenizer
|
FIXED_DIR=$(python3 /opt/fix_tokenizer.py "$MODEL_DIR")
|
||||||
AUTO_FIX=${AUTO_FIX_TOKENIZER:-auto}
|
if [ -n "$FIXED_DIR" ]; then
|
||||||
|
TOKENIZER_ARG="--tokenizer $FIXED_DIR"
|
||||||
echo "[entrypoint] model dir: $MODEL_DIR"
|
|
||||||
|
|
||||||
NEED_FIX=0
|
|
||||||
|
|
||||||
if [ "$AUTO_FIX" = "1" ] || [ "$AUTO_FIX" = "true" ]; then
|
|
||||||
NEED_FIX=1
|
|
||||||
elif [ "$AUTO_FIX" = "auto" ]; then
|
|
||||||
if [ -f "$MODEL_DIR/tokenizer_config.json" ]; then
|
|
||||||
if grep -q "TokenizersBackend\|TiktokenTokenizer" "$MODEL_DIR/tokenizer_config.json"; then
|
|
||||||
NEED_FIX=1
|
|
||||||
fi
|
|
||||||
# 检测 extra_special_tokens 是否为 list 格式
|
|
||||||
if grep -q '"extra_special_tokens":\s*\[' "$MODEL_DIR/tokenizer_config.json"; then
|
|
||||||
NEED_FIX=1
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [ $NEED_FIX -eq 1 ]; then
|
|
||||||
echo "[entrypoint] fixing tokenizer..."
|
|
||||||
python3 /opt/fix_tokenizer.py
|
|
||||||
TOKENIZER_ARG="--tokenizer $FIX_TOKENIZER_DIR"
|
|
||||||
else
|
else
|
||||||
echo "[entrypoint] tokenizer OK, skip fix"
|
|
||||||
TOKENIZER_ARG=""
|
TOKENIZER_ARG=""
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "[entrypoint] starting vllm..."
|
|
||||||
|
|
||||||
exec vllm serve "$MODEL_DIR" $TOKENIZER_ARG "$@"
|
exec vllm serve "$MODEL_DIR" $TOKENIZER_ARG "$@"
|
||||||
|
|||||||
108
fix_tokenizer.py
108
fix_tokenizer.py
@@ -1,20 +1,55 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
检测 tokenizer_config.json 中的 tokenizer_class 是否在 transformers 中存在。
|
||||||
|
若不存在(如 TokenizersBackend),则将 tokenizer 文件复制到 /tmp/fixed_tokenizer/
|
||||||
|
并修复 tokenizer_class,最后将修复目录路径输出到 stdout。
|
||||||
|
若无需修复,输出为空。
|
||||||
|
"""
|
||||||
import os
|
import os
|
||||||
import shutil
|
import sys
|
||||||
import json
|
import json
|
||||||
from detect_tokenizer import detect
|
import shutil
|
||||||
|
|
||||||
MODEL_DIR = os.environ.get("MODEL_DIR", "/model")
|
MODEL_DIR = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("MODEL_DIR", "/model")
|
||||||
OUT_DIR = os.environ.get("FIX_TOKENIZER_DIR", "/tmp/fixed_tokenizer")
|
OUT_DIR = "/tmp/fixed_tokenizer"
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
cfg_path = os.path.join(MODEL_DIR, "tokenizer_config.json")
|
||||||
|
if not os.path.exists(cfg_path):
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(cfg_path) as f:
|
||||||
|
cfg = json.load(f)
|
||||||
|
|
||||||
|
tokenizer_class = cfg.get("tokenizer_class", "")
|
||||||
|
if not tokenizer_class:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 用 transformers 自身判断该类是否可用,不硬编码类名
|
||||||
|
import transformers
|
||||||
|
if getattr(transformers, tokenizer_class, None) is not None:
|
||||||
|
return # 类存在,无需修复
|
||||||
|
|
||||||
|
# tokenizer_class 在 transformers 中不存在,根据实际文件推断正确的类
|
||||||
|
files = os.listdir(MODEL_DIR)
|
||||||
|
if "tokenizer.json" in files:
|
||||||
|
fixed_class = "PreTrainedTokenizerFast"
|
||||||
|
elif "tokenizer.model" in files:
|
||||||
|
fixed_class = "LlamaTokenizer"
|
||||||
|
elif "vocab.json" in files and "merges.txt" in files:
|
||||||
|
fixed_class = "GPT2TokenizerFast"
|
||||||
|
else:
|
||||||
|
fixed_class = "PreTrainedTokenizerFast"
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"[fix_tokenizer] tokenizer_class '{tokenizer_class}' not found in transformers, "
|
||||||
|
f"replacing with '{fixed_class}'",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
|
||||||
os.makedirs(OUT_DIR, exist_ok=True)
|
os.makedirs(OUT_DIR, exist_ok=True)
|
||||||
|
for fname in [
|
||||||
def copy_if_exists(name):
|
|
||||||
src = os.path.join(MODEL_DIR, name)
|
|
||||||
if os.path.exists(src):
|
|
||||||
shutil.copy(src, OUT_DIR)
|
|
||||||
|
|
||||||
# 复制所有可能相关文件
|
|
||||||
for f in [
|
|
||||||
"tokenizer.json",
|
"tokenizer.json",
|
||||||
"tokenizer_config.json",
|
"tokenizer_config.json",
|
||||||
"special_tokens_map.json",
|
"special_tokens_map.json",
|
||||||
@@ -22,48 +57,15 @@ for f in [
|
|||||||
"merges.txt",
|
"merges.txt",
|
||||||
"tokenizer.model",
|
"tokenizer.model",
|
||||||
]:
|
]:
|
||||||
copy_if_exists(f)
|
src = os.path.join(MODEL_DIR, fname)
|
||||||
|
if os.path.exists(src):
|
||||||
|
shutil.copy(src, OUT_DIR)
|
||||||
|
|
||||||
typ, orig_cls = detect(MODEL_DIR)
|
cfg["tokenizer_class"] = fixed_class
|
||||||
|
with open(os.path.join(OUT_DIR, "tokenizer_config.json"), "w") as f:
|
||||||
|
json.dump(cfg, f, indent=2)
|
||||||
|
|
||||||
cfg_path = os.path.join(OUT_DIR, "tokenizer_config.json")
|
print(OUT_DIR) # 输出修复目录,供 entrypoint.sh 捕获
|
||||||
|
|
||||||
if os.path.exists(cfg_path):
|
|
||||||
with open(cfg_path) as f:
|
|
||||||
cfg = json.load(f)
|
|
||||||
else:
|
|
||||||
cfg = {}
|
|
||||||
|
|
||||||
# ===== 自动修复策略 =====
|
main()
|
||||||
if typ == "fast":
|
|
||||||
cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
|
|
||||||
|
|
||||||
elif typ == "sentencepiece":
|
|
||||||
cfg["tokenizer_class"] = "LlamaTokenizer"
|
|
||||||
|
|
||||||
elif typ == "bpe":
|
|
||||||
cfg["tokenizer_class"] = "GPT2TokenizerFast"
|
|
||||||
|
|
||||||
else:
|
|
||||||
cfg["tokenizer_class"] = "PreTrainedTokenizerFast"
|
|
||||||
|
|
||||||
# 特殊 case 修复
|
|
||||||
bad_classes = [
|
|
||||||
"TokenizersBackend",
|
|
||||||
"TiktokenTokenizer",
|
|
||||||
]
|
|
||||||
|
|
||||||
if orig_cls in bad_classes:
|
|
||||||
print(f"[fix] override bad tokenizer_class: {orig_cls} → {cfg['tokenizer_class']}")
|
|
||||||
|
|
||||||
# 修复 extra_special_tokens: list → dict 格式
|
|
||||||
if "extra_special_tokens" in cfg and isinstance(cfg["extra_special_tokens"], list):
|
|
||||||
orig_list = cfg["extra_special_tokens"]
|
|
||||||
cfg["extra_special_tokens"] = {token: token for token in orig_list}
|
|
||||||
print(f"[fix] converted extra_special_tokens from list ({len(orig_list)} items) to dict format")
|
|
||||||
|
|
||||||
# 写回
|
|
||||||
with open(cfg_path, "w") as f:
|
|
||||||
json.dump(cfg, f)
|
|
||||||
|
|
||||||
print(f"[fix_tokenizer] done → {OUT_DIR}")
|
|
||||||
|
|||||||
Reference in New Issue
Block a user