release initial code
Co-authored-by: Ying Sheng <sqy1415@gmail.com> Co-authored-by: Liangsheng Yin <hnyls2002@gmail.com> Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu> Co-authored-by: parasol-aser <3848358+parasol-aser@users.noreply.github.com> Co-authored-by: LiviaSun <33578456+ChuyueSun@users.noreply.github.com> Co-authored-by: Cody Yu <hao.yu.cody@gmail.com>
This commit is contained in:
37
benchmark/line_retrieval/README.md
Normal file
37
benchmark/line_retrieval/README.md
Normal file
@@ -0,0 +1,37 @@
## Download data

```
wget https://raw.githubusercontent.com/merrymercy/merrymercy.github.io/master/files/random_words.json
python3 gen_data.py --number 1000
```

## Run benchmark

### Benchmark sglang

```
python3 -m sglang.launch_server --model-path codellama/CodeLlama-7b-hf --port 30000
```

```
python3 bench_sglang.py --src-index 600 --num-q 50 --parallel 1
```

### Results

```
# original
Accuracy: 0.940, latency: 332.83 s

# parallel encoding (no_adjust, offset = 1000)
Accuracy: 0.760, latency: 238.46 s

# parallel encoding (no_adjust, offset = 3000)
Accuracy: 0.760, latency: 238.46 s

# parallel encoding (no_adjust, offset = 0)
Accuracy: 0.520, latency: 238.46 s

# parallel encoding (adjust_cache)
Accuracy: 0.460, latency: 257.66 s
```
133
benchmark/line_retrieval/bench_sglang.py
Normal file
133
benchmark/line_retrieval/bench_sglang.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
||||
from sglang.utils import dump_state_text
|
||||
|
||||
|
||||
@sgl.function
def line_retrieval(s, prefix, suffix, body_0, body_1, body_2, body_3):
    """Prompt program: ask for one line's REGISTER_CONTENT value.

    The body is supplied in four pre-split chunks so they can be encoded in
    parallel forks rather than as one sequential prompt.
    """
    s += prefix + "\n"

    # Encode the four body chunks in parallel forks. Each fork gets a
    # position-id offset of i * 1000 — presumably to space the chunks apart
    # in position-id space for parallel encoding; TODO confirm against the
    # backend's fork() semantics.
    contexts = [body_0, body_1, body_2, body_3]
    position_ids_offset = [i * 1000 for i in range(len(contexts))]
    forks = s.fork(len(contexts), position_ids_offset)
    forks += lambda i: contexts[i] + "\n"
    # Merge the forked segments back into the main stream.
    forks.join(mode="concate_and_append")

    s += "\n" + suffix
    s += sgl.gen("answer", max_tokens=16)
|
||||
|
||||
|
||||
def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
    """Run the line-retrieval benchmark over all (src_index, dst_percent) pairs.

    For each pair, pick a query line whose whole redirection chain fits inside
    the first ``src_index + 1`` lines, build the four-chunk prompt, run the
    batch through the selected sglang backend, score the answers, and append
    one JSON result record to ``args.result_file``.

    Args:
        args: Parsed CLI namespace (needs ``backend``, ``parallel``,
            ``result_file``).
        line_obj: Data dict produced by ``gen_data.py``.
        num_hoops: Number of redirection hoops the queries must require.
        src_indices: Truncation points for the line list.
        dst_percents: Positions in [0, 1) selecting which candidate query to
            ask at each truncation point.
    """
    arguments = []
    labels = []
    sum_src_indices = []
    sum_dst_indices = []

    for src_index in src_indices:
        for dst_percent in dst_percents:
            # Candidate queries: every link in the redirection chain must be
            # resolvable within the truncated body, and the query line itself
            # must come before src_index.
            query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
            query_indices = [
                q for q in query_indices
                if all(link <= src_index for link in line_obj["links"][q])
                and q < src_index
            ]
            # Map dst_percent onto the candidate list, clamping to the last
            # element so dst_percent == 1.0 stays in range.
            dst_index = query_indices[
                min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
            ]
            label = line_obj["values"][dst_index]

            body = line_obj["lines"][: src_index + 1]
            suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
            # Four roughly equal chunks; the last chunk absorbs the remainder.
            body_part_len = len(body) // 4

            arguments.append({
                "prefix": line_obj["prefix"],
                "body_0": "\n".join(body[:body_part_len]),
                "body_1": "\n".join(body[body_part_len: 2 * body_part_len]),
                "body_2": "\n".join(body[2 * body_part_len: 3 * body_part_len]),
                "body_3": "\n".join(body[3 * body_part_len:]),
                "suffix": suffix,
            })
            labels.append(label)
            sum_src_indices.append(src_index)
            sum_dst_indices.append(dst_index)

    # Select backend
    backend = select_sglang_backend(args)

    tic = time.time()
    states = line_retrieval.run_batch(
        arguments, temperature=0, backend=backend, num_threads=args.parallel)
    latency = time.time() - tic

    corrects = []
    for i in range(len(arguments)):
        output = states[i]["answer"]
        prompt_len = states[i].get_meta_info("answer").get("prompt_length", -1)
        label = labels[i]

        # Extract every number from the answer; count the item correct if any
        # of them matches the label (the model may emit extra text).
        findall = re.findall(r"\d+", output)  # raw string: "\d" is an invalid escape otherwise
        if not findall:
            response_number = output
        else:
            for response_number in findall:
                if response_number == label:
                    break

        correct = response_number == label
        corrects.append(correct)

        # Log results
        summary = (
            f"Line index: {sum_src_indices[i]} -> {sum_dst_indices[i]}, "
            f"Prompt len: {prompt_len}, "
            f"Correct: {correct}, "
            f"Label: {label}, Predicted: {response_number}, "
        )
        print(summary)

    accuracy = np.mean(corrects)
    print(f"Accuracy: {accuracy:.3f}, latency: {latency:.2f} s")

    # Write results
    dump_state_text(f"tmp_output_{args.backend}.txt", states)

    with open(args.result_file, "a") as fout:
        value = {
            "task": "line_retrieval",
            "backend": args.backend,
            "num_gpus": 1,
            "latency": round(latency, 3),
            "num_requests": len(arguments),
            "other": {
                "num_questions": len(arguments),
                "parallel": args.parallel,
            }
        }
        fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
def main(args):
    """Load the line data and run the benchmark once per source index."""
    # Context manager closes the data file promptly; the original
    # json.load(open(...)) leaked the handle.
    with open(args.data_path, "r") as fin:
        line_obj = json.load(fin)

    num_hoops = args.num_hoops
    num_queries = args.num_queries_per_src
    # Evenly spaced destination percentiles in [0, 1); loop-invariant, so
    # computed once instead of per source index.
    dst_percents = [i * (1 / num_queries) for i in range(num_queries)]
    for src_index in args.src_index:
        eval_model(args, line_obj, num_hoops, [src_index], dst_percents)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Path to the JSON data file produced by gen_data.py.
    parser.add_argument("--data-path", type=str, default="lines_1000_0.0.json")
    # One benchmark run is performed per source index (line-list truncation point).
    parser.add_argument("--src-index", type=int, nargs="+", default=[100])
    parser.add_argument("--num-queries-per-src", type=int, default=10)
    parser.add_argument("--num-hoops", type=int, default=1)
    # Adds the shared sglang benchmark flags (backend, parallel, result-file, ...)
    # and parses the command line.
    args = add_common_sglang_args_and_parse(parser)
    main(args)
|
||||
135
benchmark/line_retrieval/gen_data.py
Normal file
135
benchmark/line_retrieval/gen_data.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""
|
||||
Generate line data for line retrieval task.
|
||||
|
||||
Usage:
|
||||
python3 gen_data.py --number 1000
|
||||
"""
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
import json
|
||||
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
|
||||
def generate_lines(random_words, num_lines, redirect_ratio):
|
||||
prefix = "Here is a list of lines, each with its corresponding REGISTER_CONTENT value. Please memorize them. Be prepared to provide the REGISTER_CONTENT value for a specific line index when I ask."
|
||||
suffix = "The list has ended. Please give the final REGISTER_CONTENT value for a specific line after resovling the redirections and references. For example, the REGISTER_CONTENT of Line __idx0__ is __val0__. The REGISTER_CONTENT of Line __idx1__ is __val1__. The REGISTER_CONTENT of Line __idx2__ is __val2__. The REGISTER_CONTENT of Line ??? is"
|
||||
|
||||
# Raw lines
|
||||
visited_indices = set([None])
|
||||
visited_values = set([None])
|
||||
|
||||
lines = []
|
||||
redirects = []
|
||||
indices = []
|
||||
values = []
|
||||
for i in tqdm(range(num_lines)):
|
||||
line_index = None
|
||||
while line_index in visited_indices:
|
||||
line_index = "-".join(np.random.choice(random_words, size=(2,)))
|
||||
visited_indices.add(line_index)
|
||||
|
||||
line_value = np.random.randint(low=0, high=999999)
|
||||
line_value = f"{line_value:06}"
|
||||
|
||||
line = f"Line {line_index}: The REGISTER_CONTENT is {line_value}."
|
||||
lines.append(line)
|
||||
redirects.append(None)
|
||||
indices.append(line_index)
|
||||
values.append(line_value)
|
||||
|
||||
# Add redirect
|
||||
if redirect_ratio > 0:
|
||||
num_redirect_lines = int(len(lines) * redirect_ratio)
|
||||
redirect_indices = np.random.choice(np.arange(len(lines)),
|
||||
size=(num_redirect_lines,), replace=False)
|
||||
for i in redirect_indices:
|
||||
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
|
||||
lines[i] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
|
||||
redirects[i] = target_idx
|
||||
|
||||
# Build links and find sources
|
||||
links = [[] for _ in range(num_lines)]
|
||||
contains_ring = set()
|
||||
for i in range(num_lines):
|
||||
if redirects[i] is None:
|
||||
continue
|
||||
|
||||
tmp_link = []
|
||||
cur = i
|
||||
visited = set()
|
||||
while redirects[cur] is not None:
|
||||
visited.add(cur)
|
||||
tmp_link.append(redirects[cur])
|
||||
cur = redirects[cur]
|
||||
|
||||
if cur in visited:
|
||||
contains_ring.add(i)
|
||||
tmp_link = None
|
||||
break
|
||||
values[i] = values[cur]
|
||||
links[i] = tmp_link
|
||||
|
||||
# Group by num_links
|
||||
group_by_num_hoops = defaultdict(list)
|
||||
for i in range(num_lines):
|
||||
if i in contains_ring:
|
||||
continue
|
||||
group_by_num_hoops[len(links[i]) + 1].append(i)
|
||||
|
||||
keys = sorted(list(group_by_num_hoops.keys()))
|
||||
for num_links in keys:
|
||||
print(f"#links: {num_links}, #lines: {len(group_by_num_hoops[num_links])}")
|
||||
|
||||
# Append few-shot examples
|
||||
hoop1_candidates = list(group_by_num_hoops[1])
|
||||
hoop1_candidate_keys = {c: max([c] + links[c]) for c in hoop1_candidates}
|
||||
hoop1_candidates.sort(key=lambda c: hoop1_candidate_keys[c])
|
||||
hoop2_candidates = list(group_by_num_hoops[2])
|
||||
hoop2_candidate_keys = {c: max([c] + links[c]) for c in hoop2_candidates}
|
||||
hoop2_candidates.sort(key=lambda c: hoop2_candidate_keys[c])
|
||||
|
||||
i = hoop1_candidates[5]
|
||||
suffix = suffix.replace("__idx0__", indices[i]).replace("__val0__", values[i])
|
||||
if len(hoop2_candidates):
|
||||
i = hoop2_candidates[0]
|
||||
suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
|
||||
i = hoop2_candidates[1]
|
||||
suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
|
||||
else:
|
||||
i = hoop1_candidates[1]
|
||||
suffix = suffix.replace("__idx1__", indices[i]).replace("__val1__", values[i])
|
||||
i = hoop1_candidates[10]
|
||||
suffix = suffix.replace("__idx2__", indices[i]).replace("__val2__", values[i])
|
||||
|
||||
obj = {
|
||||
"prefix": prefix,
|
||||
"suffix": suffix,
|
||||
"lines": lines,
|
||||
"indices": indices,
|
||||
"values": values,
|
||||
"links": links,
|
||||
"group_by_num_hoops": group_by_num_hoops,
|
||||
"contains_ring": sorted(list(contains_ring)),
|
||||
}
|
||||
return obj
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # required: otherwise args.number defaults to None and generation crashes
    # with an opaque TypeError instead of a clear CLI error.
    parser.add_argument("--number", type=int, required=True)
    parser.add_argument("--redirect-ratio", type=float, default=0.0)
    args = parser.parse_args()

    num_lines = args.number

    random_words_filename = "random_words.json"
    # Close the word list promptly (the original json.load(open(...)) leaked
    # the handle).
    with open(random_words_filename, "r") as fin:
        random_words = json.load(fin)

    # Fixed seed so regenerated data is reproducible.
    np.random.seed(42)
    obj = generate_lines(random_words, num_lines, args.redirect_ratio)

    # Distinct names for the output path and the file handle — the original
    # reused "fout" for both, shadowing the path string.
    out_path = f"lines_{num_lines}_{args.redirect_ratio:.1f}.json"
    with open(out_path, "w") as fout:
        json.dump(obj, fout, indent=2)
|
||||
Reference in New Issue
Block a user