# forked from EngineX-Cambricon/enginex-mlu370-vllm
"""Dump HuggingFace model weights for quantization.

Loads a HuggingFace checkpoint, resolves model name/version/type from the
HF config (with optional CLI overrides), casts to the requested dtype, and
hands the named parameters to ``save_weights`` from ``dump_smooth``.
"""
import argparse
import os

from transformers import (AutoModel, AutoModelForCausalLM,
                          AutoModelForSeq2SeqLM, GenerationConfig)

from vllm.transformers_utils.config import get_config

from dump_smooth import save_weights
from utils_internal import (read_model_name, str_dtype_to_torch,
                            torch_dtype_to_str)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--hf_model_dir', type=str, default=None)
    parser.add_argument('--output_dir',
                        type=str,
                        default="output_dir",
                        help="The path to save the quantized checkpoint")
    parser.add_argument('--model_version',
                        type=str,
                        default=None,
                        help="Set model version to replace parsing from _name_or_path in hf config.")
    parser.add_argument('--model_type',
                        type=str,
                        default=None,
                        help="Set model type to replace parsing from model_type in hf config."
                        "if set is None and parsed also None, then set as model_version")
    parser.add_argument('--dtype',
                        type=str,
                        choices=['auto', 'float32', 'float16', 'bfloat16'],
                        default='auto',
                        help="if auto, use unquantized weight torch_dtype in config.json, else use setted dtype")
    # NOTE(review): with action="store_true" AND default=True this flag is a
    # no-op — it can never be False. Kept as-is to preserve the CLI contract;
    # consider argparse.BooleanOptionalAction if opting out should be possible.
    parser.add_argument(
        '--dump_weights',
        action="store_true",
        default=True,
        help='dump weights of the converted model',
    )

    args = parser.parse_args()

    # parser.error (not assert) so validation survives `python -O`; the old
    # message also referenced a non-existent --model_dir option.
    if not args.hf_model_dir:
        parser.error("Please set the model directory with --hf_model_dir")

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(args.output_dir, exist_ok=True)

    # CLI overrides take precedence; otherwise values are parsed from the
    # HF config by read_model_name.
    args.model_name, args.model_version, args.model_family, args.model_type = read_model_name(
        args.hf_model_dir, args.model_version, args.model_type)

    args.hf_config = get_config(args.hf_model_dir, trust_remote_code=True)

    # 'auto' means: inherit the unquantized checkpoint's torch_dtype.
    if args.dtype == "auto":
        args.dtype = torch_dtype_to_str(args.hf_config.torch_dtype)

    args.torch_dtype = str_dtype_to_torch(args.dtype)
    args.hf_config.torch_dtype = args.torch_dtype

    # ChatGLM family needs a different Auto* loader depending on version;
    # everything else goes through AutoModelForCausalLM.
    if args.model_name == 'ChatGLMForCausalLM' and args.model_version == 'glm':
        auto_model_cls = AutoModelForSeq2SeqLM
    elif args.model_name == 'ChatGLMForCausalLM' and args.model_version == 'chatglm':
        auto_model_cls = AutoModel
    else:
        auto_model_cls = AutoModelForCausalLM
    model = auto_model_cls.from_pretrained(
        args.hf_model_dir,
        trust_remote_code=True,
        torch_dtype=args.torch_dtype)

    named_parameters = dict(model.named_parameters())
    save_weights(named_parameters, args)