初始化项目,由ModelHub XC社区提供模型
Model: Finnish-NLP/Ahma-7B Source: Original Platform
This commit is contained in:
52
EasyLM/scripts/lm_eval_json.py
Normal file
52
EasyLM/scripts/lm_eval_json.py
Normal file
@@ -0,0 +1,52 @@
|
||||
import json
|
||||
import mlxu
|
||||
from EasyLM.serving import LMClient
|
||||
|
||||
|
||||
FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
|
||||
input_file='',
|
||||
output_file='',
|
||||
prefix_field='prefix',
|
||||
text_field='text',
|
||||
until_field='until',
|
||||
eval_type='loglikelihood',
|
||||
lm_client=LMClient.get_default_config(),
|
||||
)
|
||||
|
||||
|
||||
def main(argv):
|
||||
lm_client = LMClient(FLAGS.lm_client)
|
||||
with mlxu.open_file(FLAGS.input_file, 'r') as fin:
|
||||
input_data = json.load(fin)
|
||||
|
||||
if FLAGS.eval_type == 'loglikelihood':
|
||||
prefix = input_data[FLAGS.prefix_field]
|
||||
text = input_data[FLAGS.text_field]
|
||||
loglikelihoods, is_greedys = lm_client.loglikelihood(prefix, text)
|
||||
output_data = {
|
||||
'loglikelihood': loglikelihoods,
|
||||
'is_greedy': is_greedys,
|
||||
}
|
||||
elif FLAGS.eval_type == 'loglikelihood_rolling':
|
||||
text = input_data[FLAGS.text_field]
|
||||
loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(text)
|
||||
output_data = {
|
||||
'loglikelihood': loglikelihoods,
|
||||
'is_greedy': is_greedys,
|
||||
}
|
||||
elif FLAGS.eval_type == 'greedy_until':
|
||||
prefix = input_data[FLAGS.prefix_field]
|
||||
until = input_data[FLAGS.until_field]
|
||||
output_data = {'output_text': lm_client.greedy_until(prefix, until)}
|
||||
elif FLAGS.eval_type == 'generate':
|
||||
prefix = input_data[FLAGS.prefix_field]
|
||||
output_data = {'output_text': lm_client.generate(prefix)}
|
||||
else:
|
||||
raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')
|
||||
|
||||
with mlxu.open_file(FLAGS.output_file, 'w') as fout:
|
||||
json.dump(output_data, fout)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mlxu.run(main)
|
||||
Reference in New Issue
Block a user