Files
2026-02-04 17:22:39 +08:00

70 lines
2.7 KiB
Python

import os
import argparse
from transformers import (AutoModel, AutoModelForCausalLM,
AutoModelForSeq2SeqLM, GenerationConfig)
from vllm.transformers_utils.config import get_config
from utils_internal import (read_model_name, torch_dtype_to_str, str_dtype_to_torch)
from dump_smooth import save_weights
if __name__ == '__main__':
    # CLI entry point: load a HuggingFace checkpoint, resolve its model
    # identity and dtype, then hand the named parameters to save_weights
    # for dumping/quantization.
    parser = argparse.ArgumentParser()
    parser.add_argument('--hf_model_dir', type=str, default=None)
    parser.add_argument('--output_dir',
                        type=str,
                        default="output_dir",
                        help="The path to save the quantized checkpoint")
    parser.add_argument('--model_version',
                        type=str,
                        default=None,
                        help="Set model version to replace parsing from _name_or_path in hf config.")
    parser.add_argument('--model_type',
                        type=str,
                        default=None,
                        help="Set model type to replace parsing from model_type in hf config."
                        "if set is None and parsed also None, then set as model_version")
    parser.add_argument('--dtype',
                        type=str,
                        choices=['auto', 'float32', 'float16', 'bfloat16'],
                        default='auto',
                        help="if auto, use unquantized weight torch_dtype in config.json, else use setted dtype")
    # BUG FIX: the original combined action="store_true" with default=True,
    # which made the flag a no-op (always True, impossible to disable).
    # BooleanOptionalAction (Python 3.9+) keeps the old default (True) and the
    # old positive form (--dump_weights) while adding --no-dump_weights to
    # actually turn dumping off.
    parser.add_argument(
        '--dump_weights',
        action=argparse.BooleanOptionalAction,
        default=True,
        help='dump weights of the converted model',
    )
    args = parser.parse_args()

    # BUG FIX: validate via parser.error instead of assert — asserts are
    # stripped under `python -O`, and the original message referenced a
    # nonexistent --model_dir flag.
    if not args.hf_model_dir:
        parser.error("Please set the model directory with --hf_model_dir")

    # exist_ok=True avoids the check-then-create race of the original
    # `if not os.path.exists(...)` pattern.
    os.makedirs(args.output_dir, exist_ok=True)

    # Resolve model identity; the CLI overrides (--model_version/--model_type)
    # take precedence over what read_model_name parses from the HF config.
    args.model_name, args.model_version, args.model_family, args.model_type = read_model_name(
        args.hf_model_dir, args.model_version, args.model_type)
    args.hf_config = get_config(args.hf_model_dir, trust_remote_code=True)

    # 'auto' defers to the unquantized checkpoint's torch_dtype from config.json;
    # the resolved dtype is written back onto hf_config for downstream consumers.
    if args.dtype == "auto":
        args.dtype = torch_dtype_to_str(args.hf_config.torch_dtype)
    args.torch_dtype = str_dtype_to_torch(args.dtype)
    args.hf_config.torch_dtype = args.torch_dtype

    # Pick the transformers auto-class: GLM-version ChatGLM checkpoints load as
    # seq2seq, ChatGLM-version ones via the generic AutoModel, everything else
    # as a causal LM.
    if args.model_name == 'ChatGLMForCausalLM' and args.model_version == 'glm':
        auto_model_cls = AutoModelForSeq2SeqLM
    elif args.model_name == 'ChatGLMForCausalLM' and args.model_version == 'chatglm':
        auto_model_cls = AutoModel
    else:
        auto_model_cls = AutoModelForCausalLM

    model = auto_model_cls.from_pretrained(
        args.hf_model_dir,
        trust_remote_code=True,
        torch_dtype=args.torch_dtype)
    named_parameters = dict(model.named_parameters())
    save_weights(named_parameters, args)