初始化项目,由ModelHub XC社区提供模型
Model: BoscoTheDog/bitnet_b1_58-large_q8_0_gguf Source: Original Platform
This commit is contained in:
67
eval_ppl.py
Normal file
67
eval_ppl.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import math
|
||||
import argparse
|
||||
import torch
|
||||
import random
|
||||
|
||||
from eval_utils import get_test_dataset
|
||||
from .modeling_bitnet import BitnetForCausalLM
|
||||
from .tokenization_bitnet import BitnetTokenizer
|
||||
|
||||
from tqdm import tqdm
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--seed', default=0, type=int)
|
||||
parser.add_argument('--hf_path', default='1bitLLM/bitnet_b1_58-3B', type=str)
|
||||
parser.add_argument('--seqlen', default=2048, type=int)
|
||||
|
||||
|
||||
def calulate_loss(model, input, loss_fct):
|
||||
output = model(input,
|
||||
use_cache=False,
|
||||
output_hidden_states=False,
|
||||
output_attentions=False)[0]
|
||||
shift_logits = output[:, :-1, :].contiguous()
|
||||
shift_labels = input[:, 1:]
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
return loss
|
||||
|
||||
|
||||
def main(args):
|
||||
datasets = ['c4', 'wikitext2']
|
||||
model = BitnetForCausalLM.from_pretrained(
|
||||
args.hf_path,
|
||||
device_map='auto',
|
||||
low_cpu_mem_usage=True,
|
||||
use_flash_attention_2=True,
|
||||
torch_dtype=torch.float16,
|
||||
).half()
|
||||
tokenizer = BitnetTokenizer.from_pretrained(args.hf_path, use_fast=False)
|
||||
loss_fct = torch.nn.CrossEntropyLoss(reduction="sum").cuda()
|
||||
|
||||
ppl = []
|
||||
for dataset in datasets:
|
||||
testdata = get_test_dataset(dataset, tokenizer, seqlen=args.seqlen)
|
||||
acc_loss, count = 0.0, 0
|
||||
progress = tqdm(range(len(testdata)))
|
||||
for ii in progress:
|
||||
input = torch.Tensor(testdata[ii]).long().cuda().view(1, -1)
|
||||
loss = calulate_loss(model, input, loss_fct)
|
||||
count += (input.size(-1) - 1)
|
||||
acc_loss += loss.item()
|
||||
progress.set_description(f"avg_loss = {acc_loss/ count / math.log(2)}")
|
||||
|
||||
avg_loss = acc_loss / count / math.log(2)
|
||||
ppl.append(2 ** avg_loss)
|
||||
print("{} PPL: {}".format(dataset, ppl[-1]))
|
||||
|
||||
print(ppl)
|
||||
print("Avg PPL:", sum(ppl) / len(ppl))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.set_grad_enabled(False)
|
||||
args = parser.parse_args()
|
||||
random.seed(args.seed)
|
||||
torch.random.manual_seed(args.seed)
|
||||
main(args)
|
||||
Reference in New Issue
Block a user