初始化项目,由ModelHub XC社区提供模型
Model: SlimGroove/normistral-11b-translate-mlx Source: Original Platform
This commit is contained in:
68
convert_to_safetensors.py
Normal file
68
convert_to_safetensors.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import argparse
|
||||
import random
|
||||
from statistics import mean, stdev
|
||||
from typing import List
|
||||
import torch
|
||||
import torchmetrics
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
default=".",
|
||||
help="Path to the pre-trained model",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def load_model(model_path: str):
|
||||
# Load the pre-trained model and tokenizer
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path, torch_dtype=torch.bfloat16)
|
||||
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).cuda().eval()
|
||||
|
||||
eos_token_ids = [
|
||||
token_id
|
||||
for token_id in range(tokenizer.vocab_size)
|
||||
if "\n" in tokenizer.decode([token_id])
|
||||
]
|
||||
|
||||
if hasattr(model.config, "n_positions"):
|
||||
max_length = model.config.n_positions
|
||||
elif hasattr(model.config, "max_position_embeddings"):
|
||||
max_length = model.config.max_position_embeddings
|
||||
elif hasattr(model.config, "max_length"):
|
||||
max_length = model.config.max_length
|
||||
elif hasattr(model.config, "n_ctx"):
|
||||
max_length = model.config.n_ctx
|
||||
else:
|
||||
max_length = 32768 # Default value
|
||||
|
||||
return {
|
||||
"name": model_path.split("/")[-1],
|
||||
"tokenizer": tokenizer,
|
||||
"model": model,
|
||||
"eos_token_ids": eos_token_ids,
|
||||
"max_length": max_length,
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
model = load_model(args.model_name_or_path)
|
||||
|
||||
model["model"].save_pretrained(
|
||||
args.model_name_or_path,
|
||||
max_shard_size="5GB"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user