初始化项目,由ModelHub XC社区提供模型
Model: BoscoTheDog/bitnet_b1_58-large_q8_0_gguf Source: Original Platform
This commit is contained in:
133
eval_utils.py
Normal file
133
eval_utils.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import torch
|
||||
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
from lm_eval.base import BaseLM
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
def set_seed(seed):
    """Make NumPy and PyTorch random number generation deterministic.

    Args:
        seed: integer seed applied to both global RNGs.
    """
    torch.random.manual_seed(seed)
    np.random.seed(seed)
|
||||
|
||||
def get_test_dataset(dataset_name, tokenizer, seqlen=2048):
|
||||
if dataset_name == "wikitext2":
|
||||
testdata = load_dataset('wikitext', 'wikitext-2-raw-v1', split='test')
|
||||
testdata = "".join(testdata['text']).split('\n')
|
||||
elif dataset_name == "c4":
|
||||
testdata = load_dataset('allenai/c4', data_files={'validation': 'en/c4-validation.00000-of-00008.json.gz'}, split='validation')['text']
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
testdata = [item for item in testdata if item != ""]
|
||||
tokenized_text = [tokenizer(item, add_special_tokens=False)['input_ids'] + [tokenizer.eos_token_id] for item in testdata]
|
||||
|
||||
data, doc = [], [tokenizer.bos_token_id]
|
||||
for sen in tokenized_text:
|
||||
if len(sen) > seqlen:
|
||||
continue
|
||||
if len(doc) + len(sen) > seqlen:
|
||||
data.append(doc)
|
||||
doc = [tokenizer.bos_token_id]
|
||||
doc.extend(sen)
|
||||
if len(doc) > 1 and len(doc) <= seqlen:
|
||||
data.append(doc)
|
||||
return data
|
||||
|
||||
|
||||
class LMEvalAdaptor(BaseLM):
    """Expose a HuggingFace-style causal LM through the lm_eval BaseLM interface.

    Wraps an eval-mode model and its tokenizer so lm_eval's harness can
    query batch size, context window, encoding/decoding, and logits.
    """

    def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
        """Store the model/tokenizer pair and evaluation settings.

        Args:
            model_name: identifier string; used to guess the context window
                when the model config does not expose one.
            model: the language model; switched to eval mode here.
            tokenizer: companion tokenizer for the model.
            batch_size: evaluation batch size (must be an int).
            max_length: explicit context-window override; -1 means "derive
                it from the model config / model name".
        """
        super().__init__()
        assert isinstance(batch_size, int)

        self.model_name = model_name
        self.model = model
        self.model.eval()
        self.tokenizer = tokenizer
        self.vocab_size = self.tokenizer.vocab_size

        self._batch_size = batch_size
        self._max_length = max_length

    @property
    def eot_token_id(self):
        # we use EOT because end of *text* is more accurate for what we're doing than end of *sentence*
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        """Context window in tokens: explicit override, config attribute, or family default."""
        if self._max_length != -1:
            return self._max_length
        cfg = self.model.config
        # Probe the config attribute names used by different architectures.
        for attr in ("n_ctx", "max_position_embeddings", "n_positions"):
            if hasattr(cfg, attr):
                return getattr(cfg, attr)
        # Families with a known window but no config hint.
        # ("llama" value was marked unverified in the original — TODO confirm.)
        if any(tag in self.model_name for tag in ("bloom", "llama", "mpt", "falcon")):
            return 2048
        print(cfg)
        raise NotImplementedError

    @property
    def max_gen_toks(self):
        """Upper bound on tokens produced per generation request."""
        return 256

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def device(self):
        return "cuda"

    def tok_encode(self, string: str, add_special_tokens=True):
        """Encode *string* to token ids via the wrapped tokenizer."""
        return self.tokenizer.encode(string, add_special_tokens=add_special_tokens)

    def tok_decode(self, tokens):
        """Decode token ids back to text via the wrapped tokenizer."""
        return self.tokenizer.decode(tokens)

    def loglikelihood(self, requests):
        """Score (context, continuation) pairs after whitespace-stripping both parts.

        An empty context is replaced by the end-of-text token so the model
        still has something to condition on.
        """
        prepared = []
        for context, continuation in requests:
            context = context.strip()
            continuation = continuation.strip()
            if context:
                context_enc = self.tok_encode(context, add_special_tokens=True)
            else:
                # end of text as context
                context_enc = [self.eot_token_id]
            continuation_enc = self.tok_encode(continuation, add_special_tokens=False)
            prepared.append(((context, continuation), context_enc, continuation_enc))
        return self._loglikelihood_tokens(prepared)

    def _model_call(self, inps):
        """
        inps: a torch tensor of shape [batch, sequence]
        the size of sequence may vary from call to call

        returns: a torch tensor of shape [batch, sequence, vocab] with the
        logits returned from the model
        """
        with torch.no_grad():
            return self.model(inps)[0]

    def _model_generate(self, context, max_length, eos_token_id):
        """Greedy-decode from *context* until max_length or the EOS token."""
        return self.model.generate(
            context, max_length=max_length, eos_token_id=eos_token_id, do_sample=False
        )
|
||||
Reference in New Issue
Block a user