"""Load a causal language model in bfloat16 and re-save it with sharded weights.

By default the model is written back into ``--model_name_or_path`` (in
bfloat16, split into shards of at most ``--max_shard_size``); pass
``--output_dir`` to write it somewhere else instead.
"""
import argparse
import os
import random
from statistics import mean, stdev
from typing import List

import torch
import torchmetrics
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Context length reported when the model config exposes none of the
# attributes in _CONTEXT_LENGTH_ATTRS (or all of them are None).
DEFAULT_MAX_LENGTH = 32768

# Config attributes that may hold the model's context window, in priority
# order. ``max_length`` is checked last: PretrainedConfig in many
# ``transformers`` releases defines it as a *generation* default (20) on every
# config, so checking it before ``n_ctx`` would shadow the real value.
_CONTEXT_LENGTH_ATTRS = ("n_positions", "max_position_embeddings", "n_ctx", "max_length")


def parse_args():
    """Parse command-line options for the re-sharding script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=".",
        help="Path to the pre-trained model",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Directory to save the re-sharded model to "
        "(defaults to --model_name_or_path, i.e. save in place)",
    )
    parser.add_argument(
        "--max_shard_size",
        type=str,
        default="5GB",
        help="Maximum size of each saved weight shard",
    )
    args = parser.parse_args()
    return args


def _infer_max_length(config, default=DEFAULT_MAX_LENGTH):
    """Return the first non-None context-length attribute of *config*, else *default*."""
    for attr in _CONTEXT_LENGTH_ATTRS:
        value = getattr(config, attr, None)
        # An attribute that exists but is None carries no information; keep looking.
        if value is not None:
            return value
    return default


def _model_display_name(model_path):
    """Return the last path component of *model_path*.

    Resolving to an absolute path first makes "." and paths with a trailing
    slash yield the directory's real name instead of "." or "".
    """
    return os.path.basename(os.path.abspath(model_path))


def load_model(model_path: str):
    """Load a tokenizer and a bfloat16 causal LM from *model_path*.

    Args:
        model_path: Local directory or hub id accepted by ``from_pretrained``.

    Returns:
        A dict with keys:
          ``name``: last path component of *model_path*.
          ``tokenizer``: the loaded tokenizer.
          ``model``: the model in eval mode, on CUDA when available, else CPU.
          ``eos_token_ids``: ids below ``tokenizer.vocab_size`` whose decoded
            text contains a newline.
          ``max_length``: context length read from the model config, or
            ``DEFAULT_MAX_LENGTH`` if none is found.
    """
    # ``torch_dtype`` only applies to model weights; tokenizers have no dtype.
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Fall back to CPU so the script still works on hosts without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = (
        AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
        .to(device)
        .eval()
    )

    # Collect every id whose decoded text contains "\n" (presumably used by
    # callers as generation stop tokens; not used by main() here). This decodes
    # the base vocabulary one id at a time, so it is slow for large vocabularies.
    # NOTE(review): ``vocab_size`` excludes added tokens; use ``len(tokenizer)``
    # if those should be scanned too.
    eos_token_ids = [
        token_id
        for token_id in range(tokenizer.vocab_size)
        if "\n" in tokenizer.decode([token_id])
    ]

    return {
        "name": _model_display_name(model_path),
        "tokenizer": tokenizer,
        "model": model,
        "eos_token_ids": eos_token_ids,
        "max_length": _infer_max_length(model.config),
    }


def main():
    """Load the model and save it (bfloat16, sharded) to the output directory."""
    args = parse_args()
    model = load_model(args.model_name_or_path)
    # Default to the source path so the original in-place behavior is kept.
    output_dir = args.output_dir or args.model_name_or_path
    model["model"].save_pretrained(output_dir, max_shard_size=args.max_shard_size)


if __name__ == "__main__":
    main()