Format Benchmark Code (#399)
This commit is contained in:
@@ -1,11 +1,15 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
from sglang.utils import dump_state_text
|
||||
|
||||
|
||||
@@ -35,23 +39,30 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
|
||||
dst_percent = dst_percents[j]
|
||||
|
||||
query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
|
||||
query_indices = [q for q in query_indices if
|
||||
all(l <= src_index for l in line_obj["links"][q]) and q < src_index]
|
||||
dst_index = query_indices[min(int(len(query_indices) * dst_percent), len(query_indices)-1)]
|
||||
query_indices = [
|
||||
q
|
||||
for q in query_indices
|
||||
if all(l <= src_index for l in line_obj["links"][q]) and q < src_index
|
||||
]
|
||||
dst_index = query_indices[
|
||||
min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
|
||||
]
|
||||
label = line_obj["values"][dst_index]
|
||||
|
||||
body = line_obj["lines"][:src_index+1]
|
||||
body = line_obj["lines"][: src_index + 1]
|
||||
suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
|
||||
body_part_len = len(body) // 4
|
||||
|
||||
arguments.append({
|
||||
"prefix": line_obj["prefix"],
|
||||
"body_0": "\n".join(body[:body_part_len]),
|
||||
"body_1": "\n".join(body[body_part_len: 2 * body_part_len]),
|
||||
"body_2": "\n".join(body[2 * body_part_len: 3 * body_part_len]),
|
||||
"body_3": "\n".join(body[3 * body_part_len:]),
|
||||
"suffix": suffix,
|
||||
})
|
||||
arguments.append(
|
||||
{
|
||||
"prefix": line_obj["prefix"],
|
||||
"body_0": "\n".join(body[:body_part_len]),
|
||||
"body_1": "\n".join(body[body_part_len : 2 * body_part_len]),
|
||||
"body_2": "\n".join(body[2 * body_part_len : 3 * body_part_len]),
|
||||
"body_3": "\n".join(body[3 * body_part_len :]),
|
||||
"suffix": suffix,
|
||||
}
|
||||
)
|
||||
labels.append(label)
|
||||
sum_src_indices.append(src_index)
|
||||
sum_dst_indices.append(dst_index)
|
||||
@@ -61,7 +72,12 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
|
||||
|
||||
tic = time.time()
|
||||
states = line_retrieval.run_batch(
|
||||
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True)
|
||||
arguments,
|
||||
temperature=0,
|
||||
backend=backend,
|
||||
num_threads=args.parallel,
|
||||
progress_bar=True,
|
||||
)
|
||||
latency = time.time() - tic
|
||||
|
||||
corrects = []
|
||||
@@ -79,7 +95,7 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
|
||||
if response_number == label:
|
||||
break
|
||||
|
||||
correct = (response_number == label)
|
||||
correct = response_number == label
|
||||
corrects.append(correct)
|
||||
|
||||
# Log results
|
||||
@@ -107,7 +123,7 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
|
||||
"other": {
|
||||
"num_questions": len(arguments),
|
||||
"parallel": args.parallel,
|
||||
}
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
@@ -4,12 +4,13 @@ Generate line data for line retrieval task.
|
||||
Usage:
|
||||
python3 gen_data.py --number 1000
|
||||
"""
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
import argparse
|
||||
import json
|
||||
from collections import defaultdict
|
||||
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def generate_lines(random_words, num_lines, redirect_ratio):
|
||||
@@ -42,11 +43,14 @@ def generate_lines(random_words, num_lines, redirect_ratio):
|
||||
# Add redirect
|
||||
if redirect_ratio > 0:
|
||||
num_redirect_lines = int(len(lines) * redirect_ratio)
|
||||
redirect_indices = np.random.choice(np.arange(len(lines)),
|
||||
size=(num_redirect_lines,), replace=False)
|
||||
redirect_indices = np.random.choice(
|
||||
np.arange(len(lines)), size=(num_redirect_lines,), replace=False
|
||||
)
|
||||
for i in redirect_indices:
|
||||
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
|
||||
lines[i] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
|
||||
lines[i] = (
|
||||
f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
|
||||
)
|
||||
redirects[i] = target_idx
|
||||
|
||||
# Build links and find sources
|
||||
|
||||
Reference in New Issue
Block a user