初始化项目,由ModelHub XC社区提供模型
Model: Yhyu13/LMCocktail-10.7B-v1 Source: Original Platform
This commit is contained in:
13
sciprts/convert.py
Normal file
13
sciprts/convert.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from transformers import AutoModelForCausalLM
|
||||
import torch
|
||||
import torch.utils.dlpack
|
||||
|
||||
# Load the original model
|
||||
model_name = "./mixed_llm"
|
||||
model = AutoModelForCausalLM.from_pretrained(model_name)
|
||||
|
||||
# Convert the model to a different precision
|
||||
model = model.half()
|
||||
|
||||
# Save the model as a safetensor
|
||||
model.save_pretrained(f"./mixed_llm_half", safetensors=True)
|
||||
2465
sciprts/merge.log
Normal file
2465
sciprts/merge.log
Normal file
File diff suppressed because it is too large
Load Diff
22
sciprts/merge.py
Normal file
22
sciprts/merge.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from LM_Cocktail import mix_models_by_layers
|
||||
import argparse
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model_type", type=str, default="decoder", help="Type of model to be mixed")
|
||||
parser.add_argument("--output_path", type=str, default="./mixed_llm", help="Path to save the mixed model")
|
||||
parser.add_argument("--max_length", type=int, default=100, help="Maximum length of the sequence to be generated")
|
||||
parser.add_argument("--models", type=str, nargs='+', default=["meta-llama/Llama-2-7b-chat-hf", "Shitao/llama2-ag-news"], help="Path to the models to be mixed")
|
||||
parser.add_argument("--weights", type=float, nargs='+', default=[0.7, 0.3], help="Weights of the models to be mixed")
|
||||
parser.add_argument("--save_precision", type=str, default='float32', help="mixed model saved format")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Mix Large Language Models (LLMs) and save the combined model to the path: ./mixed_llm
|
||||
model = mix_models_by_layers(
|
||||
model_names_or_paths=args.models,
|
||||
model_type=args.model_type,
|
||||
weights=args.weights,
|
||||
output_path=args.output_path)
|
||||
|
||||
print(model)
|
||||
8
sciprts/merge.sh
Executable file
8
sciprts/merge.sh
Executable file
@@ -0,0 +1,8 @@
|
||||
#!/bin/bash
|
||||
eval "$(conda shell.bash hook)"
|
||||
conda activate llmcocktail
|
||||
|
||||
CUDA_VISIBLE_DEVICES=1 python merge.py \
|
||||
--models /media/hangyu5/Home/Documents/Hugging-Face/LM_cocktail/meow /media/hangyu5/Home/Documents/Hugging-Face/LM_cocktail/SOLAR-10.7B-Instruct-v1.0 \
|
||||
--weights 0.5 0.5 \
|
||||
|& tee ./merge.log
|
||||
Reference in New Issue
Block a user