69 lines
1.7 KiB
Python
69 lines
1.7 KiB
Python
import argparse
|
|
import random
|
|
from statistics import mean, stdev
|
|
from typing import List
|
|
import torch
|
|
import torchmetrics
|
|
from datasets import load_dataset
|
|
from tqdm import tqdm
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
|
def parse_args():
    """Read this script's command-line options.

    Returns:
        The parsed ``argparse.Namespace``. Its only field is
        ``model_name_or_path`` (a string, ``"."`` when the flag is omitted).
    """
    cli = argparse.ArgumentParser()
    # The single option: where the pre-trained model lives.
    option_spec = {
        "type": str,
        "default": ".",
        "help": "Path to the pre-trained model",
    }
    cli.add_argument("--model_name_or_path", **option_spec)
    return cli.parse_args()
|
|
|
|
|
|
def load_model(model_path: str):
    """Load a causal language model and its tokenizer, in bfloat16 on CUDA.

    Args:
        model_path: Local directory or model identifier passed to
            ``from_pretrained`` for both the tokenizer and the model.

    Returns:
        A dict with these keys:
            ``"name"``: last ``/``-separated component of ``model_path``
                (trailing slashes are ignored).
            ``"tokenizer"``: the loaded tokenizer.
            ``"model"``: the model loaded as bfloat16, moved to CUDA and
                switched to eval mode.
            ``"eos_token_ids"``: every id below ``tokenizer.vocab_size``
                whose decoded text contains a newline.
            ``"max_length"``: context length read from the model config, or
                32768 when none of the known fields is set.

    Note:
        ``.cuda()`` is called unconditionally, so a CUDA device is required.
    """
    # torch_dtype is a model-loading option; it is not passed to the
    # tokenizer, which has no weights to cast.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16
    )
    model = model.cuda().eval()

    # Treat any token whose text contains a newline as a stop token.
    # NOTE(review): vocab_size may not count tokens added after training;
    # confirm whether those should be scanned as well.
    eos_token_ids = [
        token_id
        for token_id in range(tokenizer.vocab_size)
        if "\n" in tokenizer.decode([token_id])
    ]

    # Probe the config for a context-length field. ``max_length`` goes last:
    # in transformers it is a generation-length setting that most configs
    # carry, so checking it earlier would shadow ``n_ctx`` and the fallback.
    # Fields that exist but are None are skipped.
    max_length = 32768  # Default value
    for field in ("n_positions", "max_position_embeddings", "n_ctx", "max_length"):
        value = getattr(model.config, field, None)
        if value is not None:
            max_length = value
            break

    return {
        # rstrip keeps "path/to/model/" from producing an empty name.
        "name": model_path.rstrip("/").split("/")[-1],
        "tokenizer": tokenizer,
        "model": model,
        "eos_token_ids": eos_token_ids,
        "max_length": max_length,
    }
|
|
|
|
|
|
def main():
    """Reload the checkpoint named on the command line and save it in place.

    The model is loaded through ``load_model`` and then written back to the
    same ``--model_name_or_path`` location with ``max_shard_size="5GB"``.
    """
    cli_args = parse_args()
    checkpoint_dir = cli_args.model_name_or_path

    bundle = load_model(checkpoint_dir)

    # Save over the source location; only the model is written, the
    # tokenizer is left as it is.
    bundle["model"].save_pretrained(checkpoint_dir, max_shard_size="5GB")
|
|
|
|
|
|
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|