Format Benchmark Code (#399)

This commit is contained in:
Liangsheng Yin
2024-04-28 21:06:22 +08:00
committed by GitHub
parent 19818b9c2f
commit 95c4e0dfac
41 changed files with 1169 additions and 608 deletions

View File

@@ -1,11 +1,15 @@
import argparse
import json
import time
import re
import time
import numpy as np
import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
@@ -35,23 +39,30 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
dst_percent = dst_percents[j]
query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
query_indices = [q for q in query_indices if
all(l <= src_index for l in line_obj["links"][q]) and q < src_index]
dst_index = query_indices[min(int(len(query_indices) * dst_percent), len(query_indices)-1)]
query_indices = [
q
for q in query_indices
if all(l <= src_index for l in line_obj["links"][q]) and q < src_index
]
dst_index = query_indices[
min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
]
label = line_obj["values"][dst_index]
body = line_obj["lines"][:src_index+1]
body = line_obj["lines"][: src_index + 1]
suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
body_part_len = len(body) // 4
arguments.append({
"prefix": line_obj["prefix"],
"body_0": "\n".join(body[:body_part_len]),
"body_1": "\n".join(body[body_part_len: 2 * body_part_len]),
"body_2": "\n".join(body[2 * body_part_len: 3 * body_part_len]),
"body_3": "\n".join(body[3 * body_part_len:]),
"suffix": suffix,
})
arguments.append(
{
"prefix": line_obj["prefix"],
"body_0": "\n".join(body[:body_part_len]),
"body_1": "\n".join(body[body_part_len : 2 * body_part_len]),
"body_2": "\n".join(body[2 * body_part_len : 3 * body_part_len]),
"body_3": "\n".join(body[3 * body_part_len :]),
"suffix": suffix,
}
)
labels.append(label)
sum_src_indices.append(src_index)
sum_dst_indices.append(dst_index)
@@ -61,7 +72,12 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
tic = time.time()
states = line_retrieval.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True)
arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.time() - tic
corrects = []
@@ -79,7 +95,7 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
if response_number == label:
break
correct = (response_number == label)
correct = response_number == label
corrects.append(correct)
# Log results
@@ -107,7 +123,7 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
"other": {
"num_questions": len(arguments),
"parallel": args.parallel,
}
},
}
fout.write(json.dumps(value) + "\n")