Format Benchmark Code (#399)

This commit is contained in:
Liangsheng Yin
2024-04-28 21:06:22 +08:00
committed by GitHub
parent 19818b9c2f
commit 95c4e0dfac
41 changed files with 1169 additions and 608 deletions

View File

@@ -1,11 +1,15 @@
import argparse
import json
import time
import re
import time
import numpy as np
import sglang as sgl
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
@@ -35,23 +39,30 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
dst_percent = dst_percents[j]
query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
query_indices = [q for q in query_indices if
all(l <= src_index for l in line_obj["links"][q]) and q < src_index]
dst_index = query_indices[min(int(len(query_indices) * dst_percent), len(query_indices)-1)]
query_indices = [
q
for q in query_indices
if all(l <= src_index for l in line_obj["links"][q]) and q < src_index
]
dst_index = query_indices[
min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
]
label = line_obj["values"][dst_index]
body = line_obj["lines"][:src_index+1]
body = line_obj["lines"][: src_index + 1]
suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
body_part_len = len(body) // 4
arguments.append({
"prefix": line_obj["prefix"],
"body_0": "\n".join(body[:body_part_len]),
"body_1": "\n".join(body[body_part_len: 2 * body_part_len]),
"body_2": "\n".join(body[2 * body_part_len: 3 * body_part_len]),
"body_3": "\n".join(body[3 * body_part_len:]),
"suffix": suffix,
})
arguments.append(
{
"prefix": line_obj["prefix"],
"body_0": "\n".join(body[:body_part_len]),
"body_1": "\n".join(body[body_part_len : 2 * body_part_len]),
"body_2": "\n".join(body[2 * body_part_len : 3 * body_part_len]),
"body_3": "\n".join(body[3 * body_part_len :]),
"suffix": suffix,
}
)
labels.append(label)
sum_src_indices.append(src_index)
sum_dst_indices.append(dst_index)
@@ -61,7 +72,12 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
tic = time.time()
states = line_retrieval.run_batch(
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True)
arguments,
temperature=0,
backend=backend,
num_threads=args.parallel,
progress_bar=True,
)
latency = time.time() - tic
corrects = []
@@ -79,7 +95,7 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
if response_number == label:
break
correct = (response_number == label)
correct = response_number == label
corrects.append(correct)
# Log results
@@ -107,7 +123,7 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
"other": {
"num_questions": len(arguments),
"parallel": args.parallel,
}
},
}
fout.write(json.dumps(value) + "\n")

View File

@@ -4,12 +4,13 @@ Generate line data for line retrieval task.
Usage:
python3 gen_data.py --number 1000
"""
import argparse
from collections import defaultdict
import json
from tqdm import tqdm
import argparse
import json
from collections import defaultdict
import numpy as np
from tqdm import tqdm
def generate_lines(random_words, num_lines, redirect_ratio):
@@ -42,11 +43,14 @@ def generate_lines(random_words, num_lines, redirect_ratio):
# Add redirect
if redirect_ratio > 0:
num_redirect_lines = int(len(lines) * redirect_ratio)
redirect_indices = np.random.choice(np.arange(len(lines)),
size=(num_redirect_lines,), replace=False)
redirect_indices = np.random.choice(
np.arange(len(lines)), size=(num_redirect_lines,), replace=False
)
for i in redirect_indices:
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
lines[i] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
lines[i] = (
f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
)
redirects[i] = target_idx
# Build links and find sources