Format Benchmark Code (#399)
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
Adapted from
|
Adapted from
|
||||||
https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9
|
https://github.com/stanfordnlp/dspy/blob/34d8420383ec752037aa271825c1d3bf391e1277/intro.ipynb#L9
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
import dspy
|
import dspy
|
||||||
@@ -37,29 +38,41 @@ class RAG(dspy.Module):
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
#lm = dspy.OpenAI(model='gpt-3.5-turbo')
|
# lm = dspy.OpenAI(model='gpt-3.5-turbo')
|
||||||
if args.backend == "tgi":
|
if args.backend == "tgi":
|
||||||
lm = dspy.HFClientTGI(model="meta-llama/Llama-2-7b-chat-hf", port=args.port,
|
lm = dspy.HFClientTGI(
|
||||||
url="http://localhost")
|
model="meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
port=args.port,
|
||||||
|
url="http://localhost",
|
||||||
|
)
|
||||||
elif args.backend == "sglang":
|
elif args.backend == "sglang":
|
||||||
lm = dspy.HFClientSGLang(model="meta-llama/Llama-2-7b-chat-hf", port=args.port,
|
lm = dspy.HFClientSGLang(
|
||||||
url="http://localhost")
|
model="meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
port=args.port,
|
||||||
|
url="http://localhost",
|
||||||
|
)
|
||||||
elif args.backend == "vllm":
|
elif args.backend == "vllm":
|
||||||
lm = dspy.HFClientVLLM(model="meta-llama/Llama-2-7b-chat-hf", port=args.port,
|
lm = dspy.HFClientVLLM(
|
||||||
url="http://localhost")
|
model="meta-llama/Llama-2-7b-chat-hf",
|
||||||
|
port=args.port,
|
||||||
|
url="http://localhost",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid backend: {args.backend}")
|
raise ValueError(f"Invalid backend: {args.backend}")
|
||||||
|
|
||||||
colbertv2_wiki17_abstracts = dspy.ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
|
colbertv2_wiki17_abstracts = dspy.ColBERTv2(
|
||||||
|
url="http://20.102.90.50:2017/wiki17_abstracts"
|
||||||
|
)
|
||||||
dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts)
|
dspy.settings.configure(lm=lm, rm=colbertv2_wiki17_abstracts)
|
||||||
|
|
||||||
# Load the dataset.
|
# Load the dataset.
|
||||||
dataset = HotPotQA(train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size,
|
dataset = HotPotQA(
|
||||||
test_size=0)
|
train_seed=1, train_size=20, eval_seed=2023, dev_size=args.dev_size, test_size=0
|
||||||
|
)
|
||||||
|
|
||||||
# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
|
# Tell DSPy that the 'question' field is the input. Any other fields are labels and/or metadata.
|
||||||
trainset = [x.with_inputs('question') for x in dataset.train]
|
trainset = [x.with_inputs("question") for x in dataset.train]
|
||||||
devset = [x.with_inputs('question') for x in dataset.dev]
|
devset = [x.with_inputs("question") for x in dataset.dev]
|
||||||
|
|
||||||
print(len(trainset), len(devset))
|
print(len(trainset), len(devset))
|
||||||
|
|
||||||
@@ -72,8 +85,12 @@ def main(args):
|
|||||||
print(f"Answer: {dev_example.answer}")
|
print(f"Answer: {dev_example.answer}")
|
||||||
print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}")
|
print(f"Relevant Wikipedia Titles: {dev_example.gold_titles}")
|
||||||
|
|
||||||
print(f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}")
|
print(
|
||||||
print(f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}")
|
f"For this dataset, training examples have input keys {train_example.inputs().keys()} and label keys {train_example.labels().keys()}"
|
||||||
|
)
|
||||||
|
print(
|
||||||
|
f"For this dataset, dev examples have input keys {dev_example.inputs().keys()} and label keys {dev_example.labels().keys()}"
|
||||||
|
)
|
||||||
|
|
||||||
# Define the predictor.
|
# Define the predictor.
|
||||||
generate_answer = dspy.Predict(BasicQA)
|
generate_answer = dspy.Predict(BasicQA)
|
||||||
@@ -101,10 +118,14 @@ def main(args):
|
|||||||
retrieve = dspy.Retrieve(k=3)
|
retrieve = dspy.Retrieve(k=3)
|
||||||
topK_passages = retrieve(dev_example.question).passages
|
topK_passages = retrieve(dev_example.question).passages
|
||||||
|
|
||||||
print(f"Top {retrieve.k} passages for question: {dev_example.question} \n", '-' * 30, '\n')
|
print(
|
||||||
|
f"Top {retrieve.k} passages for question: {dev_example.question} \n",
|
||||||
|
"-" * 30,
|
||||||
|
"\n",
|
||||||
|
)
|
||||||
|
|
||||||
for idx, passage in enumerate(topK_passages):
|
for idx, passage in enumerate(topK_passages):
|
||||||
print(f'{idx+1}]', passage, '\n')
|
print(f"{idx+1}]", passage, "\n")
|
||||||
|
|
||||||
retrieve("When was the first FIFA World Cup held?").passages[0]
|
retrieve("When was the first FIFA World Cup held?").passages[0]
|
||||||
|
|
||||||
@@ -137,7 +158,12 @@ def main(args):
|
|||||||
from dspy.evaluate.evaluate import Evaluate
|
from dspy.evaluate.evaluate import Evaluate
|
||||||
|
|
||||||
# Set up the `evaluate_on_hotpotqa` function. We'll use this many times below.
|
# Set up the `evaluate_on_hotpotqa` function. We'll use this many times below.
|
||||||
evaluate_on_hotpotqa = Evaluate(devset=devset, num_threads=args.num_threads, display_progress=True, display_table=5)
|
evaluate_on_hotpotqa = Evaluate(
|
||||||
|
devset=devset,
|
||||||
|
num_threads=args.num_threads,
|
||||||
|
display_progress=True,
|
||||||
|
display_table=5,
|
||||||
|
)
|
||||||
|
|
||||||
# Evaluate the `compiled_rag` program with the `answer_exact_match` metric.
|
# Evaluate the `compiled_rag` program with the `answer_exact_match` metric.
|
||||||
metric = dspy.evaluate.answer_exact_match
|
metric = dspy.evaluate.answer_exact_match
|
||||||
@@ -149,8 +175,9 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument("--port", type=int)
|
parser.add_argument("--port", type=int)
|
||||||
parser.add_argument("--num-threads", type=int, default=32)
|
parser.add_argument("--num-threads", type=int, default=32)
|
||||||
parser.add_argument("--dev-size", type=int, default=150)
|
parser.add_argument("--dev-size", type=int, default=150)
|
||||||
parser.add_argument("--backend", type=str, choices=["sglang", "tgi", "vllm"],
|
parser.add_argument(
|
||||||
default="sglang")
|
"--backend", type=str, choices=["sglang", "tgi", "vllm"], default="sglang"
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.port is None:
|
if args.port is None:
|
||||||
|
|||||||
@@ -122,16 +122,36 @@ Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs
|
|||||||
* Must be one of the "Area options," verbatim.
|
* Must be one of the "Area options," verbatim.
|
||||||
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
|
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
|
||||||
---"""
|
---"""
|
||||||
s += (persona_name + " lives in " + living_sector + " that has " +
|
s += (
|
||||||
living_sector_areas + ".\n")
|
persona_name
|
||||||
s += (persona_name + " is currently in " + current_sector + " that has " +
|
+ " lives in "
|
||||||
current_sector_areas + ".\n")
|
+ living_sector
|
||||||
|
+ " that has "
|
||||||
|
+ living_sector_areas
|
||||||
|
+ ".\n"
|
||||||
|
)
|
||||||
|
s += (
|
||||||
|
persona_name
|
||||||
|
+ " is currently in "
|
||||||
|
+ current_sector
|
||||||
|
+ " that has "
|
||||||
|
+ current_sector_areas
|
||||||
|
+ ".\n"
|
||||||
|
)
|
||||||
s += daily_plan + ".\n"
|
s += daily_plan + ".\n"
|
||||||
s += "Area options: " + sector_options + ".\n"
|
s += "Area options: " + sector_options + ".\n"
|
||||||
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||||
* Must be one of the "Area options," verbatim.\n"""
|
* Must be one of the "Area options," verbatim.\n"""
|
||||||
s += (persona_name + " is " + current_action + ". For " + next_action +
|
s += (
|
||||||
", " + persona_name + " should go to the following area: {")
|
persona_name
|
||||||
|
+ " is "
|
||||||
|
+ current_action
|
||||||
|
+ ". For "
|
||||||
|
+ next_action
|
||||||
|
+ ", "
|
||||||
|
+ persona_name
|
||||||
|
+ " should go to the following area: {"
|
||||||
|
)
|
||||||
s += sgl.gen(name="Location", max_tokens=10, stop="}")
|
s += sgl.gen(name="Location", max_tokens=10, stop="}")
|
||||||
|
|
||||||
|
|
||||||
@@ -162,22 +182,43 @@ Area options: {Oak Hill College Student Dormatory, The Rose and Crown Pub, Hobbs
|
|||||||
* Must be one of the "Area options," verbatim.
|
* Must be one of the "Area options," verbatim.
|
||||||
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
|
For eating dinner, Jane Anderson should go to the following area: {Hobbs Cafe}
|
||||||
---"""
|
---"""
|
||||||
s += (persona_name + " lives in " + living_sector + " that has " +
|
s += (
|
||||||
living_sector_areas + ".\n")
|
persona_name
|
||||||
s += (persona_name + " is currently in " + current_sector + " that has " +
|
+ " lives in "
|
||||||
current_sector_areas + ".\n")
|
+ living_sector
|
||||||
|
+ " that has "
|
||||||
|
+ living_sector_areas
|
||||||
|
+ ".\n"
|
||||||
|
)
|
||||||
|
s += (
|
||||||
|
persona_name
|
||||||
|
+ " is currently in "
|
||||||
|
+ current_sector
|
||||||
|
+ " that has "
|
||||||
|
+ current_sector_areas
|
||||||
|
+ ".\n"
|
||||||
|
)
|
||||||
s += daily_plan + ".\n"
|
s += daily_plan + ".\n"
|
||||||
s += "Area options: " + sector_options + ".\n"
|
s += "Area options: " + sector_options + ".\n"
|
||||||
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
s += """* Stay in the current area if the activity can be done there. Only go out if the activity needs to take place in another place.
|
||||||
* Must be one of the "Area options," verbatim.\n"""
|
* Must be one of the "Area options," verbatim.\n"""
|
||||||
s += (persona_name + " is " + current_action + ". For " + next_action +
|
s += (
|
||||||
", " + persona_name + " should go to the following area: {")
|
persona_name
|
||||||
|
+ " is "
|
||||||
|
+ current_action
|
||||||
|
+ ". For "
|
||||||
|
+ next_action
|
||||||
|
+ ", "
|
||||||
|
+ persona_name
|
||||||
|
+ " should go to the following area: {"
|
||||||
|
)
|
||||||
return {"prompt": s, "max_tokens": 10, "stop": "}"}
|
return {"prompt": s, "max_tokens": 10, "stop": "}"}
|
||||||
|
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
def action_location_object(s, persona_name, target_sector, target_sector_areas,
|
def action_location_object(
|
||||||
current_action, next_action):
|
s, persona_name, target_sector, target_sector_areas, current_action, next_action
|
||||||
|
):
|
||||||
s += """
|
s += """
|
||||||
Jane Anderson is in kitchen in Jane Anderson's house.
|
Jane Anderson is in kitchen in Jane Anderson's house.
|
||||||
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
|
Jane Anderson is going to Jane Anderson's house that has the following areas: {kitchen, bedroom, bathroom}
|
||||||
@@ -191,20 +232,34 @@ Stay in the current area if the activity can be done there. Never go into other
|
|||||||
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
|
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
|
||||||
Answer: {cafe}
|
Answer: {cafe}
|
||||||
---"""
|
---"""
|
||||||
s += (persona_name + " is going to " + target_sector +
|
s += (
|
||||||
" that has the following areas: {" + target_sector_areas + "}\n")
|
persona_name
|
||||||
|
+ " is going to "
|
||||||
|
+ target_sector
|
||||||
|
+ " that has the following areas: {"
|
||||||
|
+ target_sector_areas
|
||||||
|
+ "}\n"
|
||||||
|
)
|
||||||
s += """* Stay in the current area if the activity can be done there.
|
s += """* Stay in the current area if the activity can be done there.
|
||||||
* NEVER go into other people's rooms unless necessary."""
|
* NEVER go into other people's rooms unless necessary."""
|
||||||
s += (persona_name + " is " + current_action + ". For " + next_action +
|
s += (
|
||||||
", " + persona_name + "should go to the following area in " +
|
persona_name
|
||||||
target_sector)
|
+ " is "
|
||||||
|
+ current_action
|
||||||
|
+ ". For "
|
||||||
|
+ next_action
|
||||||
|
+ ", "
|
||||||
|
+ persona_name
|
||||||
|
+ "should go to the following area in "
|
||||||
|
+ target_sector
|
||||||
|
)
|
||||||
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
|
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
|
||||||
s += "Answer: {" + sgl.gen(name="Area", max_tokens=5, stop="}")
|
s += "Answer: {" + sgl.gen(name="Area", max_tokens=5, stop="}")
|
||||||
|
|
||||||
|
|
||||||
def action_location_object_prompt(persona_name, target_sector,
|
def action_location_object_prompt(
|
||||||
target_sector_areas, current_action,
|
persona_name, target_sector, target_sector_areas, current_action, next_action
|
||||||
next_action):
|
):
|
||||||
s = ""
|
s = ""
|
||||||
s += """
|
s += """
|
||||||
Jane Anderson is in kitchen in Jane Anderson's house.
|
Jane Anderson is in kitchen in Jane Anderson's house.
|
||||||
@@ -219,13 +274,27 @@ Stay in the current area if the activity can be done there. Never go into other
|
|||||||
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
|
For getting coffee, Tom Watson should go to the following area in Hobbs Cafe:
|
||||||
Answer: {cafe}
|
Answer: {cafe}
|
||||||
---"""
|
---"""
|
||||||
s += (persona_name + " is going to " + target_sector +
|
s += (
|
||||||
" that has the following areas: {" + target_sector_areas + "}\n")
|
persona_name
|
||||||
|
+ " is going to "
|
||||||
|
+ target_sector
|
||||||
|
+ " that has the following areas: {"
|
||||||
|
+ target_sector_areas
|
||||||
|
+ "}\n"
|
||||||
|
)
|
||||||
s += """* Stay in the current area if the activity can be done there.
|
s += """* Stay in the current area if the activity can be done there.
|
||||||
* NEVER go into other people's rooms unless necessary."""
|
* NEVER go into other people's rooms unless necessary."""
|
||||||
s += (persona_name + " is " + current_action + ". For " + next_action +
|
s += (
|
||||||
", " + persona_name + "should go to the following area in " +
|
persona_name
|
||||||
target_sector)
|
+ " is "
|
||||||
|
+ current_action
|
||||||
|
+ ". For "
|
||||||
|
+ next_action
|
||||||
|
+ ", "
|
||||||
|
+ persona_name
|
||||||
|
+ "should go to the following area in "
|
||||||
|
+ target_sector
|
||||||
|
)
|
||||||
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
|
s += " (MUST pick one of {" + target_sector_areas + "}):\n"
|
||||||
s += "Answer: {"
|
s += "Answer: {"
|
||||||
return {"prompt": s, "max_tokens": 5, "stop": "}"}
|
return {"prompt": s, "max_tokens": 5, "stop": "}"}
|
||||||
|
|||||||
@@ -1,29 +1,29 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from functools import partial
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from agent_functions import (
|
||||||
|
action_location_object_prompt,
|
||||||
|
action_location_sector_prompt,
|
||||||
|
generate_event_triple_prompt,
|
||||||
|
generate_pronunciatio_prompt,
|
||||||
|
poignancy_event_prompt,
|
||||||
|
)
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
add_common_other_args_and_parse,
|
add_common_other_args_and_parse,
|
||||||
call_generate_lightllm,
|
call_generate_lightllm,
|
||||||
call_generate_vllm,
|
|
||||||
call_generate_srt_raw,
|
call_generate_srt_raw,
|
||||||
|
call_generate_vllm,
|
||||||
)
|
)
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
from agent_functions import (
|
|
||||||
poignancy_event_prompt,
|
|
||||||
generate_event_triple_prompt,
|
|
||||||
generate_pronunciatio_prompt,
|
|
||||||
action_location_sector_prompt,
|
|
||||||
action_location_object_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
lines = read_jsonl(args.data_path)[:args.num_events]
|
lines = read_jsonl(args.data_path)[: args.num_events]
|
||||||
mapping = {
|
mapping = {
|
||||||
"poignancy_event": poignancy_event_prompt,
|
"poignancy_event": poignancy_event_prompt,
|
||||||
"generate_event_triple": generate_event_triple_prompt,
|
"generate_event_triple": generate_event_triple_prompt,
|
||||||
@@ -46,7 +46,7 @@ def main(args):
|
|||||||
url = f"{args.host}:{args.port}/generate"
|
url = f"{args.host}:{args.port}/generate"
|
||||||
call_generate = partial(call_generate_srt_raw, url=url)
|
call_generate = partial(call_generate_srt_raw, url=url)
|
||||||
elif args.backend == "guidance":
|
elif args.backend == "guidance":
|
||||||
from guidance import models, gen
|
from guidance import gen, models
|
||||||
|
|
||||||
model = models.LlamaCpp(
|
model = models.LlamaCpp(
|
||||||
str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf",
|
str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf",
|
||||||
@@ -55,11 +55,15 @@ def main(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def call_generate(prompt, temperature, max_tokens, stop):
|
def call_generate(prompt, temperature, max_tokens, stop):
|
||||||
out = model + prompt + gen(
|
out = (
|
||||||
name="result",
|
model
|
||||||
max_tokens=max_tokens,
|
+ prompt
|
||||||
temperature=temperature,
|
+ gen(
|
||||||
stop=stop,
|
name="result",
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return out["result"]
|
return out["result"]
|
||||||
|
|
||||||
|
|||||||
@@ -2,24 +2,24 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from agent_functions import (
|
||||||
|
action_location_object,
|
||||||
|
action_location_sector,
|
||||||
|
generate_event_triple,
|
||||||
|
generate_pronunciatio,
|
||||||
|
poignancy_event,
|
||||||
|
)
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
add_common_sglang_args_and_parse,
|
add_common_sglang_args_and_parse,
|
||||||
select_sglang_backend,
|
select_sglang_backend,
|
||||||
)
|
)
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
from agent_functions import (
|
|
||||||
poignancy_event,
|
|
||||||
generate_event_triple,
|
|
||||||
generate_pronunciatio,
|
|
||||||
action_location_sector,
|
|
||||||
action_location_object,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
lines = read_jsonl(args.data_path)[:args.num_events]
|
lines = read_jsonl(args.data_path)[: args.num_events]
|
||||||
mapping = {
|
mapping = {
|
||||||
"poignancy_event": poignancy_event,
|
"poignancy_event": poignancy_event,
|
||||||
"generate_event_triple": generate_event_triple,
|
"generate_event_triple": generate_event_triple,
|
||||||
|
|||||||
@@ -1,17 +1,22 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
import asyncio
|
import asyncio
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from functools import partial
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
|
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_other_args_and_parse,
|
||||||
|
call_generate_lightllm,
|
||||||
|
call_generate_srt_raw,
|
||||||
|
call_generate_vllm,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
INVALID = -9999999
|
INVALID = -9999999
|
||||||
|
|
||||||
@@ -32,7 +37,7 @@ def get_few_shot_examples(lines, k):
|
|||||||
|
|
||||||
def get_answer_value(answer_str):
|
def get_answer_value(answer_str):
|
||||||
answer_str = answer_str.replace(",", "")
|
answer_str = answer_str.replace(",", "")
|
||||||
numbers = re.findall(r'\d+', answer_str)
|
numbers = re.findall(r"\d+", answer_str)
|
||||||
if len(numbers) < 1:
|
if len(numbers) < 1:
|
||||||
return INVALID
|
return INVALID
|
||||||
try:
|
try:
|
||||||
@@ -50,7 +55,7 @@ def main(args):
|
|||||||
|
|
||||||
questions = []
|
questions = []
|
||||||
labels = []
|
labels = []
|
||||||
for i in range(len(lines[:args.num_questions])):
|
for i in range(len(lines[: args.num_questions])):
|
||||||
questions.append(get_one_example(lines, i, False))
|
questions.append(get_one_example(lines, i, False))
|
||||||
labels.append(get_answer_value(lines[i]["answer"]))
|
labels.append(get_answer_value(lines[i]["answer"]))
|
||||||
assert all(l != INVALID for l in labels)
|
assert all(l != INVALID for l in labels)
|
||||||
@@ -68,19 +73,31 @@ def main(args):
|
|||||||
url = f"{args.host}:{args.port}/generate"
|
url = f"{args.host}:{args.port}/generate"
|
||||||
call_generate = partial(call_generate_srt_raw, url=url)
|
call_generate = partial(call_generate_srt_raw, url=url)
|
||||||
elif args.backend == "guidance":
|
elif args.backend == "guidance":
|
||||||
from guidance import models, gen
|
from guidance import gen, models
|
||||||
|
|
||||||
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
|
model = models.LlamaCpp(
|
||||||
|
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
|
||||||
|
n_gpu_layers=-1,
|
||||||
|
n_ctx=4096,
|
||||||
|
)
|
||||||
|
|
||||||
def call_generate(prompt, temperature, max_tokens, stop):
|
def call_generate(prompt, temperature, max_tokens, stop):
|
||||||
out = model + prompt + gen(name="answer",
|
out = (
|
||||||
max_tokens=max_tokens, temperature=temperature, stop=stop)
|
model
|
||||||
|
+ prompt
|
||||||
|
+ gen(
|
||||||
|
name="answer",
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
|
)
|
||||||
return out["answer"]
|
return out["answer"]
|
||||||
|
|
||||||
elif args.backend == "lmql":
|
elif args.backend == "lmql":
|
||||||
import lmql
|
import lmql
|
||||||
model = lmql.model(args.model_path,
|
|
||||||
endpoint=f"{args.host}:{args.port}")
|
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
|
||||||
|
|
||||||
@lmql.query(model=model)
|
@lmql.query(model=model)
|
||||||
async def program(question):
|
async def program(question):
|
||||||
@@ -103,7 +120,8 @@ def main(args):
|
|||||||
prompt=few_shot_examples + questions[i],
|
prompt=few_shot_examples + questions[i],
|
||||||
temperature=0,
|
temperature=0,
|
||||||
max_tokens=256,
|
max_tokens=256,
|
||||||
stop="Question")
|
stop="Question",
|
||||||
|
)
|
||||||
states[i] = answer
|
states[i] = answer
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
@@ -118,12 +136,18 @@ def main(args):
|
|||||||
async def batched_call(batch_size):
|
async def batched_call(batch_size):
|
||||||
for i in range(0, len(questions), batch_size):
|
for i in range(0, len(questions), batch_size):
|
||||||
tasks = []
|
tasks = []
|
||||||
for q in questions[i:i+batch_size]:
|
for q in questions[i : i + batch_size]:
|
||||||
tasks.append(call_generate(few_shot_examples + q,
|
tasks.append(
|
||||||
temperature=0, max_tokens=256, stop="Question"))
|
call_generate(
|
||||||
|
few_shot_examples + q,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=256,
|
||||||
|
stop="Question",
|
||||||
|
)
|
||||||
|
)
|
||||||
rets = await asyncio.gather(*tasks)
|
rets = await asyncio.gather(*tasks)
|
||||||
for j in range(len(rets)):
|
for j in range(len(rets)):
|
||||||
states[i+j] = rets[j]
|
states[i + j] = rets[j]
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
asyncio.run(batched_call(batch_size=args.parallel))
|
asyncio.run(batched_call(batch_size=args.parallel))
|
||||||
@@ -154,7 +178,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -5,9 +5,12 @@ import re
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
INVALID = -9999999
|
INVALID = -9999999
|
||||||
|
|
||||||
@@ -28,7 +31,7 @@ def get_few_shot_examples(lines, k):
|
|||||||
|
|
||||||
def get_answer_value(answer_str):
|
def get_answer_value(answer_str):
|
||||||
answer_str = answer_str.replace(",", "")
|
answer_str = answer_str.replace(",", "")
|
||||||
numbers = re.findall(r'\d+', answer_str)
|
numbers = re.findall(r"\d+", answer_str)
|
||||||
if len(numbers) < 1:
|
if len(numbers) < 1:
|
||||||
return INVALID
|
return INVALID
|
||||||
try:
|
try:
|
||||||
@@ -46,7 +49,7 @@ def main(args):
|
|||||||
|
|
||||||
questions = []
|
questions = []
|
||||||
labels = []
|
labels = []
|
||||||
for i in range(len(lines[:args.num_questions])):
|
for i in range(len(lines[: args.num_questions])):
|
||||||
questions.append(get_one_example(lines, i, False))
|
questions.append(get_one_example(lines, i, False))
|
||||||
labels.append(get_answer_value(lines[i]["answer"]))
|
labels.append(get_answer_value(lines[i]["answer"]))
|
||||||
assert all(l != INVALID for l in labels)
|
assert all(l != INVALID for l in labels)
|
||||||
@@ -73,7 +76,12 @@ def main(args):
|
|||||||
# Run requests
|
# Run requests
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
states = few_shot_gsm8k.run_batch(
|
states = few_shot_gsm8k.run_batch(
|
||||||
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True)
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
backend=backend,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
preds = []
|
preds = []
|
||||||
@@ -101,7 +109,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -1,17 +1,22 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
import json
|
import json
|
||||||
from functools import partial
|
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sglang.test.test_utils import add_common_other_args_and_parse, call_select_lightllm, call_select_vllm
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_other_args_and_parse,
|
||||||
|
call_select_lightllm,
|
||||||
|
call_select_vllm,
|
||||||
|
)
|
||||||
from sglang.utils import read_jsonl
|
from sglang.utils import read_jsonl
|
||||||
|
|
||||||
|
|
||||||
def get_one_example(lines, i, include_answer):
|
def get_one_example(lines, i, include_answer):
|
||||||
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
|
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
|
||||||
if include_answer:
|
if include_answer:
|
||||||
ret += lines[i]["endings"][lines[i]["label"]]
|
ret += lines[i]["endings"][lines[i]["label"]]
|
||||||
return ret
|
return ret
|
||||||
@@ -34,7 +39,7 @@ def main(args):
|
|||||||
questions = []
|
questions = []
|
||||||
choices = []
|
choices = []
|
||||||
labels = []
|
labels = []
|
||||||
for i in range(len(lines[:args.num_questions])):
|
for i in range(len(lines[: args.num_questions])):
|
||||||
questions.append(get_one_example(lines, i, False))
|
questions.append(get_one_example(lines, i, False))
|
||||||
choices.append(lines[i]["endings"])
|
choices.append(lines[i]["endings"])
|
||||||
labels.append(lines[i]["label"])
|
labels.append(lines[i]["label"])
|
||||||
@@ -51,7 +56,11 @@ def main(args):
|
|||||||
elif args.backend == "guidance":
|
elif args.backend == "guidance":
|
||||||
from guidance import models, select
|
from guidance import models, select
|
||||||
|
|
||||||
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
|
model = models.LlamaCpp(
|
||||||
|
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
|
||||||
|
n_gpu_layers=-1,
|
||||||
|
n_ctx=4096,
|
||||||
|
)
|
||||||
|
|
||||||
def call_select(context, choices):
|
def call_select(context, choices):
|
||||||
out = model + context + select(choices, name="answer")
|
out = model + context + select(choices, name="answer")
|
||||||
@@ -61,8 +70,10 @@ def main(args):
|
|||||||
|
|
||||||
elif args.backend == "lmql":
|
elif args.backend == "lmql":
|
||||||
import lmql
|
import lmql
|
||||||
model = lmql.model("meta-llama/Llama-2-7b-chat-hf",
|
|
||||||
endpoint=f"{args.host}:{args.port}")
|
model = lmql.model(
|
||||||
|
"meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}"
|
||||||
|
)
|
||||||
|
|
||||||
@lmql.query(model=model)
|
@lmql.query(model=model)
|
||||||
async def program(ctx, choices):
|
async def program(ctx, choices):
|
||||||
@@ -83,8 +94,8 @@ def main(args):
|
|||||||
# Use thread pool
|
# Use thread pool
|
||||||
def get_one_answer(i):
|
def get_one_answer(i):
|
||||||
preds[i] = call_select(
|
preds[i] = call_select(
|
||||||
context=few_shot_examples + questions[i],
|
context=few_shot_examples + questions[i], choices=choices[i]
|
||||||
choices=choices[i])
|
)
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
if args.parallel == 1:
|
if args.parallel == 1:
|
||||||
@@ -98,13 +109,13 @@ def main(args):
|
|||||||
async def batched_call(batch_size):
|
async def batched_call(batch_size):
|
||||||
for i in range(0, len(questions), batch_size):
|
for i in range(0, len(questions), batch_size):
|
||||||
tasks = []
|
tasks = []
|
||||||
for q, c in zip(questions[i:i+batch_size], choices[i:i+batch_size]):
|
for q, c in zip(
|
||||||
tasks.append(call_select(
|
questions[i : i + batch_size], choices[i : i + batch_size]
|
||||||
context=few_shot_examples + q,
|
):
|
||||||
choices=c))
|
tasks.append(call_select(context=few_shot_examples + q, choices=c))
|
||||||
rets = await asyncio.gather(*tasks)
|
rets = await asyncio.gather(*tasks)
|
||||||
for j in range(len(rets)):
|
for j in range(len(rets)):
|
||||||
preds[i+j] = rets[j]
|
preds[i + j] = rets[j]
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
asyncio.run(batched_call(batch_size=args.parallel))
|
asyncio.run(batched_call(batch_size=args.parallel))
|
||||||
@@ -128,7 +139,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -3,12 +3,16 @@ import json
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
from sglang.utils import read_jsonl
|
from sglang.utils import read_jsonl
|
||||||
|
|
||||||
|
|
||||||
def get_one_example(lines, i, include_answer):
|
def get_one_example(lines, i, include_answer):
|
||||||
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
|
ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " "
|
||||||
if include_answer:
|
if include_answer:
|
||||||
ret += lines[i]["endings"][lines[i]["label"]]
|
ret += lines[i]["endings"][lines[i]["label"]]
|
||||||
return ret
|
return ret
|
||||||
@@ -31,14 +35,11 @@ def main(args):
|
|||||||
questions = []
|
questions = []
|
||||||
choices = []
|
choices = []
|
||||||
labels = []
|
labels = []
|
||||||
for i in range(len(lines[:args.num_questions])):
|
for i in range(len(lines[: args.num_questions])):
|
||||||
questions.append(get_one_example(lines, i, False))
|
questions.append(get_one_example(lines, i, False))
|
||||||
choices.append(lines[i]["endings"])
|
choices.append(lines[i]["endings"])
|
||||||
labels.append(lines[i]["label"])
|
labels.append(lines[i]["label"])
|
||||||
arguments = [
|
arguments = [{"question": q, "choices": c} for q, c in zip(questions, choices)]
|
||||||
{"question": q, "choices": c}
|
|
||||||
for q, c in zip(questions, choices)
|
|
||||||
]
|
|
||||||
|
|
||||||
#####################################
|
#####################################
|
||||||
######### SGL Program Begin #########
|
######### SGL Program Begin #########
|
||||||
@@ -61,7 +62,12 @@ def main(args):
|
|||||||
# Run requests
|
# Run requests
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
rets = few_shot_hellaswag.run_batch(
|
rets = few_shot_hellaswag.run_batch(
|
||||||
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True)
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
backend=backend,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
|
preds = [choices[i].index(rets[i]["answer"]) for i in range(len(rets))]
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
@@ -82,7 +88,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -4,13 +4,14 @@ import time
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
add_common_other_args_and_parse,
|
add_common_other_args_and_parse,
|
||||||
call_generate_outlines,
|
call_generate_outlines,
|
||||||
)
|
)
|
||||||
from sglang.utils import dump_state_text, read_jsonl
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
|
REGEX_LIST = r"\[(" + REGEX_STRING + ", )*" + REGEX_STRING + r"\]"
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import json
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.lang.ir import REGEX_INT, REGEX_STRING, REGEX_FLOAT
|
from sglang.lang.ir import REGEX_FLOAT, REGEX_INT, REGEX_STRING
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
add_common_sglang_args_and_parse,
|
add_common_sglang_args_and_parse,
|
||||||
select_sglang_backend,
|
select_sglang_backend,
|
||||||
@@ -63,7 +63,9 @@ def main(args):
|
|||||||
|
|
||||||
# Run requests
|
# Run requests
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
states = json_decode.run_batch(arguments, temperature=0, num_threads=args.parallel, progress_bar=True)
|
states = json_decode.run_batch(
|
||||||
|
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
|
||||||
|
)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
# Compute accuracy
|
# Compute accuracy
|
||||||
|
|||||||
@@ -5,12 +5,13 @@ from concurrent.futures import ThreadPoolExecutor
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
import guidance
|
import guidance
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
add_common_other_args_and_parse,
|
add_common_other_args_and_parse,
|
||||||
call_generate_outlines,
|
call_generate_outlines,
|
||||||
)
|
)
|
||||||
from sglang.utils import dump_state_text, read_jsonl
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# there are some FSM bugs with json regex converted from pydantic model
|
# there are some FSM bugs with json regex converted from pydantic model
|
||||||
# here use a string regex instead
|
# here use a string regex instead
|
||||||
|
|||||||
@@ -15,16 +15,17 @@ On the client side, run:
|
|||||||
--tokenizer <your_model> --dataset <target_dataset> \
|
--tokenizer <your_model> --dataset <target_dataset> \
|
||||||
--request-rate <request_rate>
|
--request-rate <request_rate>
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
from typing import AsyncGenerator, List, Tuple
|
from typing import AsyncGenerator, List, Tuple
|
||||||
from tqdm.asyncio import tqdm_asyncio
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from tqdm.asyncio import tqdm_asyncio
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
@@ -41,10 +42,7 @@ def sample_requests(
|
|||||||
with open(dataset_path) as f:
|
with open(dataset_path) as f:
|
||||||
dataset = json.load(f)
|
dataset = json.load(f)
|
||||||
# Filter out the conversations with less than 2 turns.
|
# Filter out the conversations with less than 2 turns.
|
||||||
dataset = [
|
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
|
||||||
data for data in dataset
|
|
||||||
if len(data["conversations"]) >= 2
|
|
||||||
]
|
|
||||||
# Only keep the first two turns of each conversation.
|
# Only keep the first two turns of each conversation.
|
||||||
dataset = [
|
dataset = [
|
||||||
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
(data["conversations"][0]["value"], data["conversations"][1]["value"])
|
||||||
@@ -185,9 +183,17 @@ async def benchmark(
|
|||||||
tasks: List[asyncio.Task] = []
|
tasks: List[asyncio.Task] = []
|
||||||
async for request in get_request(input_requests, request_rate):
|
async for request in get_request(input_requests, request_rate):
|
||||||
prompt, prompt_len, output_len = request
|
prompt, prompt_len, output_len = request
|
||||||
task = asyncio.create_task(send_request(backend, api_url, prompt,
|
task = asyncio.create_task(
|
||||||
prompt_len, output_len,
|
send_request(
|
||||||
best_of, use_beam_search))
|
backend,
|
||||||
|
api_url,
|
||||||
|
prompt,
|
||||||
|
prompt_len,
|
||||||
|
output_len,
|
||||||
|
best_of,
|
||||||
|
use_beam_search,
|
||||||
|
)
|
||||||
|
)
|
||||||
tasks.append(task)
|
tasks.append(task)
|
||||||
await tqdm_asyncio.gather(*tasks)
|
await tqdm_asyncio.gather(*tasks)
|
||||||
|
|
||||||
@@ -202,8 +208,16 @@ def main(args: argparse.Namespace):
|
|||||||
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
|
||||||
|
|
||||||
benchmark_start_time = time.perf_counter()
|
benchmark_start_time = time.perf_counter()
|
||||||
asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
|
asyncio.run(
|
||||||
args.use_beam_search, args.request_rate))
|
benchmark(
|
||||||
|
args.backend,
|
||||||
|
api_url,
|
||||||
|
input_requests,
|
||||||
|
args.best_of,
|
||||||
|
args.use_beam_search,
|
||||||
|
args.request_rate,
|
||||||
|
)
|
||||||
|
)
|
||||||
benchmark_end_time = time.perf_counter()
|
benchmark_end_time = time.perf_counter()
|
||||||
benchmark_time = benchmark_end_time - benchmark_start_time
|
benchmark_time = benchmark_end_time - benchmark_start_time
|
||||||
print(f"Total time: {benchmark_time:.2f} s")
|
print(f"Total time: {benchmark_time:.2f} s")
|
||||||
@@ -212,43 +226,61 @@ def main(args: argparse.Namespace):
|
|||||||
# Compute the latency statistics.
|
# Compute the latency statistics.
|
||||||
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
|
avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
|
||||||
print(f"Average latency: {avg_latency:.2f} s")
|
print(f"Average latency: {avg_latency:.2f} s")
|
||||||
avg_per_token_latency = np.mean([
|
avg_per_token_latency = np.mean(
|
||||||
latency / (prompt_len + output_len)
|
[
|
||||||
for prompt_len, output_len, latency in REQUEST_LATENCY
|
latency / (prompt_len + output_len)
|
||||||
])
|
for prompt_len, output_len, latency in REQUEST_LATENCY
|
||||||
|
]
|
||||||
|
)
|
||||||
print(f"Average latency per token: {avg_per_token_latency:.2f} s")
|
print(f"Average latency per token: {avg_per_token_latency:.2f} s")
|
||||||
avg_per_output_token_latency = np.mean([
|
avg_per_output_token_latency = np.mean(
|
||||||
latency / output_len
|
[latency / output_len for _, output_len, latency in REQUEST_LATENCY]
|
||||||
for _, output_len, latency in REQUEST_LATENCY
|
)
|
||||||
])
|
print("Average latency per output token: " f"{avg_per_output_token_latency:.2f} s")
|
||||||
print("Average latency per output token: "
|
|
||||||
f"{avg_per_output_token_latency:.2f} s")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Benchmark the online serving throughput.")
|
description="Benchmark the online serving throughput."
|
||||||
parser.add_argument("--backend", type=str, default="vllm",
|
)
|
||||||
choices=["vllm", "tgi", "srt", "lightllm"])
|
parser.add_argument(
|
||||||
|
"--backend",
|
||||||
|
type=str,
|
||||||
|
default="vllm",
|
||||||
|
choices=["vllm", "tgi", "srt", "lightllm"],
|
||||||
|
)
|
||||||
parser.add_argument("--host", type=str, default="localhost")
|
parser.add_argument("--host", type=str, default="localhost")
|
||||||
parser.add_argument("--port", type=int, default=8000)
|
parser.add_argument("--port", type=int, default=8000)
|
||||||
parser.add_argument("--dataset", type=str, required=True,
|
parser.add_argument(
|
||||||
help="Path to the dataset.")
|
"--dataset", type=str, required=True, help="Path to the dataset."
|
||||||
parser.add_argument("--tokenizer", type=str, required=True,
|
)
|
||||||
help="Name or path of the tokenizer.")
|
parser.add_argument(
|
||||||
parser.add_argument("--best-of", type=int, default=1,
|
"--tokenizer", type=str, required=True, help="Name or path of the tokenizer."
|
||||||
help="Generates `best_of` sequences per prompt and "
|
)
|
||||||
"returns the best one.")
|
parser.add_argument(
|
||||||
|
"--best-of",
|
||||||
|
type=int,
|
||||||
|
default=1,
|
||||||
|
help="Generates `best_of` sequences per prompt and " "returns the best one.",
|
||||||
|
)
|
||||||
parser.add_argument("--use-beam-search", action="store_true")
|
parser.add_argument("--use-beam-search", action="store_true")
|
||||||
parser.add_argument("--num-prompts", type=int, default=1000,
|
parser.add_argument(
|
||||||
help="Number of prompts to process.")
|
"--num-prompts", type=int, default=1000, help="Number of prompts to process."
|
||||||
parser.add_argument("--request-rate", type=float, default=float("inf"),
|
)
|
||||||
help="Number of requests per second. If this is inf, "
|
parser.add_argument(
|
||||||
"then all the requests are sent at time 0. "
|
"--request-rate",
|
||||||
"Otherwise, we use Poisson process to synthesize "
|
type=float,
|
||||||
"the request arrival times.")
|
default=float("inf"),
|
||||||
|
help="Number of requests per second. If this is inf, "
|
||||||
|
"then all the requests are sent at time 0. "
|
||||||
|
"Otherwise, we use Poisson process to synthesize "
|
||||||
|
"the request arrival times.",
|
||||||
|
)
|
||||||
parser.add_argument("--seed", type=int, default=0)
|
parser.add_argument("--seed", type=int, default=0)
|
||||||
parser.add_argument('--trust-remote-code', action='store_true',
|
parser.add_argument(
|
||||||
help='trust remote code from huggingface')
|
"--trust-remote-code",
|
||||||
|
action="store_true",
|
||||||
|
help="trust remote code from huggingface",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
main(args)
|
main(args)
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
from sglang.utils import dump_state_text
|
from sglang.utils import dump_state_text
|
||||||
|
|
||||||
|
|
||||||
@@ -35,23 +39,30 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
|
|||||||
dst_percent = dst_percents[j]
|
dst_percent = dst_percents[j]
|
||||||
|
|
||||||
query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
|
query_indices = line_obj["group_by_num_hoops"][str(num_hoops)]
|
||||||
query_indices = [q for q in query_indices if
|
query_indices = [
|
||||||
all(l <= src_index for l in line_obj["links"][q]) and q < src_index]
|
q
|
||||||
dst_index = query_indices[min(int(len(query_indices) * dst_percent), len(query_indices)-1)]
|
for q in query_indices
|
||||||
|
if all(l <= src_index for l in line_obj["links"][q]) and q < src_index
|
||||||
|
]
|
||||||
|
dst_index = query_indices[
|
||||||
|
min(int(len(query_indices) * dst_percent), len(query_indices) - 1)
|
||||||
|
]
|
||||||
label = line_obj["values"][dst_index]
|
label = line_obj["values"][dst_index]
|
||||||
|
|
||||||
body = line_obj["lines"][:src_index+1]
|
body = line_obj["lines"][: src_index + 1]
|
||||||
suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
|
suffix = line_obj["suffix"].replace("???", line_obj["indices"][dst_index])
|
||||||
body_part_len = len(body) // 4
|
body_part_len = len(body) // 4
|
||||||
|
|
||||||
arguments.append({
|
arguments.append(
|
||||||
"prefix": line_obj["prefix"],
|
{
|
||||||
"body_0": "\n".join(body[:body_part_len]),
|
"prefix": line_obj["prefix"],
|
||||||
"body_1": "\n".join(body[body_part_len: 2 * body_part_len]),
|
"body_0": "\n".join(body[:body_part_len]),
|
||||||
"body_2": "\n".join(body[2 * body_part_len: 3 * body_part_len]),
|
"body_1": "\n".join(body[body_part_len : 2 * body_part_len]),
|
||||||
"body_3": "\n".join(body[3 * body_part_len:]),
|
"body_2": "\n".join(body[2 * body_part_len : 3 * body_part_len]),
|
||||||
"suffix": suffix,
|
"body_3": "\n".join(body[3 * body_part_len :]),
|
||||||
})
|
"suffix": suffix,
|
||||||
|
}
|
||||||
|
)
|
||||||
labels.append(label)
|
labels.append(label)
|
||||||
sum_src_indices.append(src_index)
|
sum_src_indices.append(src_index)
|
||||||
sum_dst_indices.append(dst_index)
|
sum_dst_indices.append(dst_index)
|
||||||
@@ -61,7 +72,12 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
|
|||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
states = line_retrieval.run_batch(
|
states = line_retrieval.run_batch(
|
||||||
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True)
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
backend=backend,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
corrects = []
|
corrects = []
|
||||||
@@ -79,7 +95,7 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
|
|||||||
if response_number == label:
|
if response_number == label:
|
||||||
break
|
break
|
||||||
|
|
||||||
correct = (response_number == label)
|
correct = response_number == label
|
||||||
corrects.append(correct)
|
corrects.append(correct)
|
||||||
|
|
||||||
# Log results
|
# Log results
|
||||||
@@ -107,7 +123,7 @@ def eval_model(args, line_obj, num_hoops, src_indices, dst_percents):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": len(arguments),
|
"num_questions": len(arguments),
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -4,12 +4,13 @@ Generate line data for line retrieval task.
|
|||||||
Usage:
|
Usage:
|
||||||
python3 gen_data.py --number 1000
|
python3 gen_data.py --number 1000
|
||||||
"""
|
"""
|
||||||
import argparse
|
|
||||||
from collections import defaultdict
|
|
||||||
import json
|
|
||||||
|
|
||||||
from tqdm import tqdm
|
import argparse
|
||||||
|
import json
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
def generate_lines(random_words, num_lines, redirect_ratio):
|
def generate_lines(random_words, num_lines, redirect_ratio):
|
||||||
@@ -42,11 +43,14 @@ def generate_lines(random_words, num_lines, redirect_ratio):
|
|||||||
# Add redirect
|
# Add redirect
|
||||||
if redirect_ratio > 0:
|
if redirect_ratio > 0:
|
||||||
num_redirect_lines = int(len(lines) * redirect_ratio)
|
num_redirect_lines = int(len(lines) * redirect_ratio)
|
||||||
redirect_indices = np.random.choice(np.arange(len(lines)),
|
redirect_indices = np.random.choice(
|
||||||
size=(num_redirect_lines,), replace=False)
|
np.arange(len(lines)), size=(num_redirect_lines,), replace=False
|
||||||
|
)
|
||||||
for i in redirect_indices:
|
for i in redirect_indices:
|
||||||
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
|
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
|
||||||
lines[i] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
|
lines[i] = (
|
||||||
|
f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
|
||||||
|
)
|
||||||
redirects[i] = target_idx
|
redirects[i] = target_idx
|
||||||
|
|
||||||
# Build links and find sources
|
# Build links and find sources
|
||||||
|
|||||||
@@ -1,13 +1,16 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import json
|
import json
|
||||||
import time
|
|
||||||
import os
|
import os
|
||||||
|
import time
|
||||||
|
|
||||||
|
import tqdm
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
import tqdm
|
from sglang.test.test_utils import (
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
add_common_sglang_args_and_parse,
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
select_sglang_backend,
|
||||||
from PIL import Image
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
@@ -17,17 +20,19 @@ def image_qa(s, image_file, question):
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
lines = read_jsonl(args.question_file)[:args.num_questions]
|
lines = read_jsonl(args.question_file)[: args.num_questions]
|
||||||
arguments = [
|
arguments = [
|
||||||
{"image_file":
|
{
|
||||||
os.path.abspath(args.image_folder + "/" + l["image"]),
|
"image_file": os.path.abspath(args.image_folder + "/" + l["image"]),
|
||||||
"question": l["text"]} for l in lines
|
"question": l["text"],
|
||||||
|
}
|
||||||
|
for l in lines
|
||||||
]
|
]
|
||||||
#arguments = [
|
# arguments = [
|
||||||
# {"image_file":
|
# {"image_file":
|
||||||
# Image.open(os.path.abspath(args.image_folder + "/" + l["image"])),
|
# Image.open(os.path.abspath(args.image_folder + "/" + l["image"])),
|
||||||
# "question": l["text"]} for l in lines
|
# "question": l["text"]} for l in lines
|
||||||
#]
|
# ]
|
||||||
|
|
||||||
states = [None] * len(lines)
|
states = [None] * len(lines)
|
||||||
|
|
||||||
@@ -41,17 +46,12 @@ def main(args):
|
|||||||
for i in tqdm.tqdm(range(len(lines))):
|
for i in tqdm.tqdm(range(len(lines))):
|
||||||
image_file = arguments[i]["image_file"]
|
image_file = arguments[i]["image_file"]
|
||||||
question = arguments[i]["question"]
|
question = arguments[i]["question"]
|
||||||
ret = image_qa.run(
|
ret = image_qa.run(image_file=image_file, question=question, temperature=0)
|
||||||
image_file=image_file,
|
|
||||||
question=question,
|
|
||||||
temperature=0)
|
|
||||||
states[i] = ret
|
states[i] = ret
|
||||||
else:
|
else:
|
||||||
states = image_qa.run_batch(
|
states = image_qa.run_batch(
|
||||||
arguments,
|
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
|
||||||
temperature=0,
|
)
|
||||||
num_threads=args.parallel,
|
|
||||||
progress_bar=True)
|
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
print(f"Latency: {latency:.3f}")
|
print(f"Latency: {latency:.3f}")
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
# Create the 'images' directory if it doesn't exist
|
# Create the 'images' directory if it doesn't exist
|
||||||
if not os.path.exists('images'):
|
if not os.path.exists("images"):
|
||||||
os.makedirs('images')
|
os.makedirs("images")
|
||||||
|
|
||||||
# Base URL
|
# Base URL
|
||||||
base_url = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/"
|
base_url = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/"
|
||||||
|
|||||||
@@ -1,27 +1,28 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from functools import partial
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
|
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
system_prompt = (
|
add_common_other_args_and_parse,
|
||||||
"Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
|
call_generate_lightllm,
|
||||||
|
call_generate_srt_raw,
|
||||||
|
call_generate_vllm,
|
||||||
)
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
|
||||||
|
|
||||||
dimension_prompts = [
|
dimension_prompts = [
|
||||||
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
|
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
|
||||||
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
|
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
|
||||||
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
|
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
|
||||||
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
|
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
|
||||||
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
|
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
|
||||||
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
|
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -31,12 +32,16 @@ def multi_dimension_judge(article, generate):
|
|||||||
|
|
||||||
judges = []
|
judges = []
|
||||||
for i in range(len(dimension_prompts)):
|
for i in range(len(dimension_prompts)):
|
||||||
comp = generate(s +
|
comp = generate(
|
||||||
"USER: Please judge the quality based on the following metric. " +
|
s
|
||||||
dimension_prompts[i] + " Please provide a single-paragraph judgement. " +
|
+ "USER: Please judge the quality based on the following metric. "
|
||||||
"Focus on the provided metric and do not say other things. "
|
+ dimension_prompts[i]
|
||||||
'End your judgement paragraph with the word "END"\nJUDGE:',
|
+ " Please provide a single-paragraph judgement. "
|
||||||
max_tokens=256, stop="END")
|
+ "Focus on the provided metric and do not say other things. "
|
||||||
|
'End your judgement paragraph with the word "END"\nJUDGE:',
|
||||||
|
max_tokens=256,
|
||||||
|
stop="END",
|
||||||
|
)
|
||||||
judges.append(comp)
|
judges.append(comp)
|
||||||
|
|
||||||
s += "I will judge the quality based on the following metrics.\n"
|
s += "I will judge the quality based on the following metrics.\n"
|
||||||
@@ -50,7 +55,7 @@ def multi_dimension_judge(article, generate):
|
|||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
lines = read_jsonl(args.data_path)[:args.num_questions]
|
lines = read_jsonl(args.data_path)[: args.num_questions]
|
||||||
states = [None] * len(lines)
|
states = [None] * len(lines)
|
||||||
|
|
||||||
# Select backend
|
# Select backend
|
||||||
@@ -64,13 +69,20 @@ def main(args):
|
|||||||
url = f"{args.host}:{args.port}/generate"
|
url = f"{args.host}:{args.port}/generate"
|
||||||
generate = partial(call_generate_srt_raw, url=url, temperature=0)
|
generate = partial(call_generate_srt_raw, url=url, temperature=0)
|
||||||
elif args.backend == "guidance":
|
elif args.backend == "guidance":
|
||||||
from guidance import models, gen
|
from guidance import gen, models
|
||||||
|
|
||||||
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
|
model = models.LlamaCpp(
|
||||||
|
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
|
||||||
|
n_gpu_layers=-1,
|
||||||
|
n_ctx=4096,
|
||||||
|
)
|
||||||
|
|
||||||
def generate(prompt, max_tokens, stop):
|
def generate(prompt, max_tokens, stop):
|
||||||
out = model + prompt + gen(name="answer",
|
out = (
|
||||||
max_tokens=max_tokens, temperature=0, stop=stop)
|
model
|
||||||
|
+ prompt
|
||||||
|
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
|
||||||
|
)
|
||||||
return out["answer"]
|
return out["answer"]
|
||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
@@ -107,7 +119,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -2,23 +2,22 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
from sglang.test.test_utils import (
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
|
||||||
system_prompt = (
|
|
||||||
"Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
|
|
||||||
)
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
system_prompt = "Please serve as an impartial judge and rigorously evaluate the quality of the following article. Apply the most stringent standards possible, showing no leniency."
|
||||||
|
|
||||||
dimension_prompts = [
|
dimension_prompts = [
|
||||||
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
|
"Content: This refers to the essences of the essay. The substance should be well researched, accurate, relevant to the topic and should show a thorough understanding of the subject. The essay should also reflect a clear goal or purpose.",
|
||||||
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
|
"Organization and Structure: An essay needs to be properly structured with a clear introduction, body, and conclusion. The essay should flow naturally, with one paragraph leading seamlessly into the next.",
|
||||||
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
|
"Argument and Analysis: The argument made in the essay should be logical, coherent and clearly articulated. Each point made should be backed up by solid evidence and thorough analysis.",
|
||||||
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
|
"Clarity and Precision: The essay should be written in a clear and concise manner. The points made should be easily understood by the reader. The language used should also be precise and unambiguous.",
|
||||||
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
|
"Grammar and Punctuation: Proper use of grammar and punctuation is vital in an academic essay. Errors in grammar and punctuation not only distract the reader but can also negatively impact the meaning and interpretation of the content.",
|
||||||
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
|
"Referencing and Citation: An essay should contain proper citations and references for all sources used. This not only prevents accusations of plagiarism but also gives credit to the authors of the works that have contributed to the essay. The citation should adhere to a specific format as required by the academic institution or specified by the professor.",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -29,23 +28,31 @@ def multi_dimension_judge(s, article):
|
|||||||
|
|
||||||
forks = s.fork(len(dimension_prompts))
|
forks = s.fork(len(dimension_prompts))
|
||||||
for i in range(len(dimension_prompts)):
|
for i in range(len(dimension_prompts)):
|
||||||
forks[i] += ("USER: Please judge the quality based on the following metric. " +
|
forks[i] += (
|
||||||
dimension_prompts[i] + " Please provide a single-paragraph judgement. " +
|
"USER: Please judge the quality based on the following metric. "
|
||||||
"Focus on the provided metric and do not say other things. "
|
+ dimension_prompts[i]
|
||||||
'End your judgement paragraph with the word "END"\nJUDGE:')
|
+ " Please provide a single-paragraph judgement. "
|
||||||
|
+ "Focus on the provided metric and do not say other things. "
|
||||||
|
'End your judgement paragraph with the word "END"\nJUDGE:'
|
||||||
|
)
|
||||||
forks[i] += sgl.gen("judgement", max_tokens=256, stop="END")
|
forks[i] += sgl.gen("judgement", max_tokens=256, stop="END")
|
||||||
forks.join()
|
forks.join()
|
||||||
|
|
||||||
s += "I will judge the quality based on the following metrics.\n"
|
s += "I will judge the quality based on the following metrics.\n"
|
||||||
for i in range(len(dimension_prompts)):
|
for i in range(len(dimension_prompts)):
|
||||||
s += dimension_prompts[i].split(":")[0] + ": " + forks[i]["judgement"].strip() + "\n"
|
s += (
|
||||||
|
dimension_prompts[i].split(":")[0]
|
||||||
|
+ ": "
|
||||||
|
+ forks[i]["judgement"].strip()
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
|
s += "In summary, on a scale of 1 to 10, I would give the article a score of"
|
||||||
s += sgl.gen("score", max_tokens=2)
|
s += sgl.gen("score", max_tokens=2)
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
lines = read_jsonl(args.data_path)[:args.num_questions]
|
lines = read_jsonl(args.data_path)[: args.num_questions]
|
||||||
arguments = [{"article": l} for l in lines]
|
arguments = [{"article": l} for l in lines]
|
||||||
|
|
||||||
# Select backend
|
# Select backend
|
||||||
@@ -54,7 +61,12 @@ def main(args):
|
|||||||
# Run requests
|
# Run requests
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
states = multi_dimension_judge.run_batch(
|
states = multi_dimension_judge.run_batch(
|
||||||
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True)
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
backend=backend,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
print(f"Latency: {latency:.3f}")
|
print(f"Latency: {latency:.3f}")
|
||||||
@@ -72,7 +84,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -1,21 +1,25 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from functools import partial
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
|
||||||
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
|
from sglang.test.test_utils import (
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
add_common_other_args_and_parse,
|
||||||
|
call_generate_lightllm,
|
||||||
|
call_generate_srt_raw,
|
||||||
|
call_generate_vllm,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
|
||||||
def json_decode(document, generate):
|
def json_decode(document, generate):
|
||||||
s = "Please extract the information of a city from the following wikipedia page.\n"
|
s = "Please extract the information of a city from the following wikipedia page.\n"
|
||||||
s += "Page begin.\n" + document + "Page end.\n"
|
s += "Page begin.\n" + document + "Page end.\n"
|
||||||
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||||
s += '{\n'
|
s += "{\n"
|
||||||
s += ' "name": "'
|
s += ' "name": "'
|
||||||
s += generate(s, max_tokens=8, stop='"') + '",\n'
|
s += generate(s, max_tokens=8, stop='"') + '",\n'
|
||||||
s += ' "country": "'
|
s += ' "country": "'
|
||||||
@@ -24,17 +28,19 @@ def json_decode(document, generate):
|
|||||||
s += generate(s, max_tokens=8, stop='"') + '",\n'
|
s += generate(s, max_tokens=8, stop='"') + '",\n'
|
||||||
s += ' "top 3 landmarks": "'
|
s += ' "top 3 landmarks": "'
|
||||||
s += generate(s, max_tokens=24, stop='"') + '",\n'
|
s += generate(s, max_tokens=24, stop='"') + '",\n'
|
||||||
s += '}\n'
|
s += "}\n"
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
lines = read_jsonl(args.data_path)
|
lines = read_jsonl(args.data_path)
|
||||||
arguments = []
|
arguments = []
|
||||||
for i in range(len(lines[:args.num_questions])):
|
for i in range(len(lines[: args.num_questions])):
|
||||||
arguments.append({
|
arguments.append(
|
||||||
"document": lines[i]["document"],
|
{
|
||||||
})
|
"document": lines[i]["document"],
|
||||||
|
}
|
||||||
|
)
|
||||||
states = [None] * len(arguments)
|
states = [None] * len(arguments)
|
||||||
|
|
||||||
# Select backend
|
# Select backend
|
||||||
@@ -48,13 +54,20 @@ def main(args):
|
|||||||
url = f"{args.host}:{args.port}/generate"
|
url = f"{args.host}:{args.port}/generate"
|
||||||
generate = partial(call_generate_srt_raw, url=url, temperature=0)
|
generate = partial(call_generate_srt_raw, url=url, temperature=0)
|
||||||
elif args.backend == "guidance":
|
elif args.backend == "guidance":
|
||||||
from guidance import models, gen
|
from guidance import gen, models
|
||||||
|
|
||||||
model = models.LlamaCpp("/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf", n_gpu_layers=-1, n_ctx=11000)
|
model = models.LlamaCpp(
|
||||||
|
"/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf",
|
||||||
|
n_gpu_layers=-1,
|
||||||
|
n_ctx=11000,
|
||||||
|
)
|
||||||
|
|
||||||
def generate(prompt, max_tokens, stop):
|
def generate(prompt, max_tokens, stop):
|
||||||
out = model + prompt + gen(name="answer",
|
out = (
|
||||||
max_tokens=max_tokens, temperature=0, stop=stop)
|
model
|
||||||
|
+ prompt
|
||||||
|
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
|
||||||
|
)
|
||||||
return out["answer"]
|
return out["answer"]
|
||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
@@ -91,7 +104,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,12 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
from sglang.test.test_utils import (
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
@@ -13,21 +15,31 @@ def json_decode(s, document):
|
|||||||
s += "Please extract the information of a city from the following wikipedia page.\n"
|
s += "Please extract the information of a city from the following wikipedia page.\n"
|
||||||
s += "Page begin.\n" + document + "Page end.\n"
|
s += "Page begin.\n" + document + "Page end.\n"
|
||||||
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
s += "Here is the name, country, and symbol of the city in JSON format.\n"
|
||||||
s += '{\n'
|
s += "{\n"
|
||||||
s += ' "name": "' + sgl.gen("name", max_tokens=8, stop='"') + '",\n'
|
s += ' "name": "' + sgl.gen("name", max_tokens=8, stop='"') + '",\n'
|
||||||
s += ' "country": "' + sgl.gen("country", max_tokens=8, stop='"') + '",\n'
|
s += ' "country": "' + sgl.gen("country", max_tokens=8, stop='"') + '",\n'
|
||||||
s += ' "air port code": "' + sgl.gen("air port code", max_tokens=8, stop='"') + '",\n'
|
s += (
|
||||||
s += ' "top 3 landmarks": "' + sgl.gen("landmarks", max_tokens=24, stop='"') + '",\n'
|
' "air port code": "'
|
||||||
s += '}\n'
|
+ sgl.gen("air port code", max_tokens=8, stop='"')
|
||||||
|
+ '",\n'
|
||||||
|
)
|
||||||
|
s += (
|
||||||
|
' "top 3 landmarks": "'
|
||||||
|
+ sgl.gen("landmarks", max_tokens=24, stop='"')
|
||||||
|
+ '",\n'
|
||||||
|
)
|
||||||
|
s += "}\n"
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
lines = read_jsonl(args.data_path)
|
lines = read_jsonl(args.data_path)
|
||||||
arguments = []
|
arguments = []
|
||||||
for i in range(len(lines[:args.num_questions])):
|
for i in range(len(lines[: args.num_questions])):
|
||||||
arguments.append({
|
arguments.append(
|
||||||
"document": lines[i]["document"],
|
{
|
||||||
})
|
"document": lines[i]["document"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# Select backend
|
# Select backend
|
||||||
backend = select_sglang_backend(args)
|
backend = select_sglang_backend(args)
|
||||||
@@ -36,7 +48,8 @@ def main(args):
|
|||||||
# Run requests
|
# Run requests
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
states = json_decode.run_batch(
|
states = json_decode.run_batch(
|
||||||
arguments, temperature=0, num_threads=args.parallel, progress_bar=True)
|
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
|
||||||
|
)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
# Compute accuracy
|
# Compute accuracy
|
||||||
@@ -55,7 +68,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import json
|
|||||||
import transformers
|
import transformers
|
||||||
import wikipedia
|
import wikipedia
|
||||||
|
|
||||||
|
|
||||||
name = "meta-llama/Llama-2-7b-chat-hf"
|
name = "meta-llama/Llama-2-7b-chat-hf"
|
||||||
t = transformers.AutoTokenizer.from_pretrained(name)
|
t = transformers.AutoTokenizer.from_pretrained(name)
|
||||||
city_names = ["los angles", "london", "tokyo", "beijing", "singapore"]
|
city_names = ["los angles", "london", "tokyo", "beijing", "singapore"]
|
||||||
@@ -20,7 +19,9 @@ for city_name in city_names:
|
|||||||
truncate_tokens = t.encode(truncate_content)
|
truncate_tokens = t.encode(truncate_content)
|
||||||
|
|
||||||
# Count token
|
# Count token
|
||||||
print(f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}")
|
print(
|
||||||
|
f"city_name: {city_name}, #tokens: {len(tokens)}, #truncate tokens: {len(truncate_tokens)}"
|
||||||
|
)
|
||||||
|
|
||||||
with open("questions.jsonl", "a") as fout:
|
with open("questions.jsonl", "a") as fout:
|
||||||
fout.write(json.dumps({"document": truncate_content}) + "\n")
|
fout.write(json.dumps({"document": truncate_content}) + "\n")
|
||||||
|
|||||||
@@ -1,17 +1,22 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
import json
|
import json
|
||||||
from functools import partial
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
|
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_other_args_and_parse,
|
||||||
|
call_generate_lightllm,
|
||||||
|
call_generate_srt_raw,
|
||||||
|
call_generate_vllm,
|
||||||
|
)
|
||||||
|
|
||||||
choices = ["A", "B", "C", "D"]
|
choices = ["A", "B", "C", "D"]
|
||||||
|
|
||||||
@@ -25,18 +30,22 @@ def format_subject(subject):
|
|||||||
s += " " + entry
|
s += " " + entry
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
def format_example(df, idx, include_answer=True):
|
def format_example(df, idx, include_answer=True):
|
||||||
prompt = df.iloc[idx, 0]
|
prompt = df.iloc[idx, 0]
|
||||||
k = df.shape[1] - 2
|
k = df.shape[1] - 2
|
||||||
for j in range(k):
|
for j in range(k):
|
||||||
prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j+1])
|
prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
|
||||||
prompt += "\nAnswer:"
|
prompt += "\nAnswer:"
|
||||||
if include_answer:
|
if include_answer:
|
||||||
prompt += " {}\n\n".format(df.iloc[idx, k + 1])
|
prompt += " {}\n\n".format(df.iloc[idx, k + 1])
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def gen_prompt(train_df, subject, k=-1):
|
def gen_prompt(train_df, subject, k=-1):
|
||||||
prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(format_subject(subject))
|
prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(
|
||||||
|
format_subject(subject)
|
||||||
|
)
|
||||||
if k == -1:
|
if k == -1:
|
||||||
k = train_df.shape[0]
|
k = train_df.shape[0]
|
||||||
for i in range(k):
|
for i in range(k):
|
||||||
@@ -63,7 +72,7 @@ def evaluate(args, subject, dev_df, test_df):
|
|||||||
prompt = train_prompt + prompt_end
|
prompt = train_prompt + prompt_end
|
||||||
prompts.append(prompt)
|
prompts.append(prompt)
|
||||||
|
|
||||||
label = test_df.iloc[i, test_df.shape[1]-1]
|
label = test_df.iloc[i, test_df.shape[1] - 1]
|
||||||
labels.append(label)
|
labels.append(label)
|
||||||
|
|
||||||
preds = [None] * len(prompts)
|
preds = [None] * len(prompts)
|
||||||
@@ -82,17 +91,24 @@ def evaluate(args, subject, dev_df, test_df):
|
|||||||
url = f"{args.host}:{args.port}/generate"
|
url = f"{args.host}:{args.port}/generate"
|
||||||
call_generate = partial(call_generate_srt_raw, url=url, stop=None)
|
call_generate = partial(call_generate_srt_raw, url=url, stop=None)
|
||||||
elif args.backend == "guidance":
|
elif args.backend == "guidance":
|
||||||
from guidance import models, gen
|
from guidance import gen, models
|
||||||
|
|
||||||
if model_initialized is None:
|
if model_initialized is None:
|
||||||
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
|
model = models.LlamaCpp(
|
||||||
|
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
|
||||||
|
n_gpu_layers=-1,
|
||||||
|
n_ctx=4096,
|
||||||
|
)
|
||||||
model_initialized = model
|
model_initialized = model
|
||||||
else:
|
else:
|
||||||
model = model_initialized
|
model = model_initialized
|
||||||
|
|
||||||
def call_generate(prompt, temperature, max_tokens):
|
def call_generate(prompt, temperature, max_tokens):
|
||||||
out = model + prompt + gen(name="answer",
|
out = (
|
||||||
max_tokens=max_tokens, temperature=0)
|
model
|
||||||
|
+ prompt
|
||||||
|
+ gen(name="answer", max_tokens=max_tokens, temperature=0)
|
||||||
|
)
|
||||||
return out["answer"]
|
return out["answer"]
|
||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
@@ -100,8 +116,10 @@ def evaluate(args, subject, dev_df, test_df):
|
|||||||
|
|
||||||
elif args.backend == "lmql":
|
elif args.backend == "lmql":
|
||||||
import lmql
|
import lmql
|
||||||
model = lmql.model("meta-llama/Llama-2-7b-chat-hf",
|
|
||||||
endpoint=f"{args.host}:{args.port}")
|
model = lmql.model(
|
||||||
|
"meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}"
|
||||||
|
)
|
||||||
|
|
||||||
@lmql.query(model=model)
|
@lmql.query(model=model)
|
||||||
async def program(question):
|
async def program(question):
|
||||||
@@ -112,6 +130,7 @@ def evaluate(args, subject, dev_df, test_df):
|
|||||||
|
|
||||||
async def call_generate(prompt, temperature, max_tokens):
|
async def call_generate(prompt, temperature, max_tokens):
|
||||||
return await program(question=prompt, temperature=temperature)
|
return await program(question=prompt, temperature=temperature)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Invalid backend: {args.backend}")
|
raise ValueError(f"Invalid backend: {args.backend}")
|
||||||
|
|
||||||
@@ -119,8 +138,7 @@ def evaluate(args, subject, dev_df, test_df):
|
|||||||
if args.backend != "lmql":
|
if args.backend != "lmql":
|
||||||
# Use thread pool
|
# Use thread pool
|
||||||
def get_one_answer(i):
|
def get_one_answer(i):
|
||||||
pred = call_generate(prompts[i], temperature=0,
|
pred = call_generate(prompts[i], temperature=0, max_tokens=max_tokens)
|
||||||
max_tokens=max_tokens)
|
|
||||||
preds[i] = pred.strip()[0]
|
preds[i] = pred.strip()[0]
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
@@ -135,12 +153,11 @@ def evaluate(args, subject, dev_df, test_df):
|
|||||||
async def batched_call(batch_size):
|
async def batched_call(batch_size):
|
||||||
for i in range(0, len(prompts), batch_size):
|
for i in range(0, len(prompts), batch_size):
|
||||||
tasks = []
|
tasks = []
|
||||||
for p in prompts[i:i+batch_size]:
|
for p in prompts[i : i + batch_size]:
|
||||||
tasks.append(call_generate(p,
|
tasks.append(call_generate(p, temperature=0, max_tokens=max_tokens))
|
||||||
temperature=0, max_tokens=max_tokens))
|
|
||||||
rets = await asyncio.gather(*tasks)
|
rets = await asyncio.gather(*tasks)
|
||||||
for j in range(len(rets)):
|
for j in range(len(rets)):
|
||||||
preds[i+j] = rets[j].strip()[0]
|
preds[i + j] = rets[j].strip()[0]
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
asyncio.run(batched_call(batch_size=args.parallel))
|
asyncio.run(batched_call(batch_size=args.parallel))
|
||||||
@@ -151,22 +168,35 @@ def evaluate(args, subject, dev_df, test_df):
|
|||||||
acc = np.mean(cors)
|
acc = np.mean(cors)
|
||||||
cors = np.array(cors)
|
cors = np.array(cors)
|
||||||
|
|
||||||
print("Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
|
print(
|
||||||
acc, latency, len(prompts), subject))
|
"Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
|
||||||
|
acc, latency, len(prompts), subject
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return cors, acc, latency
|
return cors, acc, latency
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
subjects = sorted([f.split("_test.csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test")) if "_test.csv" in f])
|
subjects = sorted(
|
||||||
|
[
|
||||||
|
f.split("_test.csv")[0]
|
||||||
|
for f in os.listdir(os.path.join(args.data_dir, "test"))
|
||||||
|
if "_test.csv" in f
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
all_cors = []
|
all_cors = []
|
||||||
all_latencies = []
|
all_latencies = []
|
||||||
num_requests = 0
|
num_requests = 0
|
||||||
|
|
||||||
for subject in tqdm(subjects[:args.nsub]):
|
for subject in tqdm(subjects[: args.nsub]):
|
||||||
dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None)[:args.ntrain]
|
dev_df = pd.read_csv(
|
||||||
test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None)
|
os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
|
||||||
|
)[: args.ntrain]
|
||||||
|
test_df = pd.read_csv(
|
||||||
|
os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
|
||||||
|
)
|
||||||
|
|
||||||
cors, acc, latency = evaluate(args, subject, dev_df, test_df)
|
cors, acc, latency = evaluate(args, subject, dev_df, test_df)
|
||||||
all_cors.append(cors)
|
all_cors.append(cors)
|
||||||
@@ -191,7 +221,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"nsub": args.nsub,
|
"nsub": args.nsub,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -7,8 +7,11 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
|
||||||
choices = ["A", "B", "C", "D"]
|
choices = ["A", "B", "C", "D"]
|
||||||
|
|
||||||
@@ -22,24 +25,29 @@ def format_subject(subject):
|
|||||||
s += " " + entry
|
s += " " + entry
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
def format_example(df, idx, include_answer=True):
|
def format_example(df, idx, include_answer=True):
|
||||||
prompt = df.iloc[idx, 0]
|
prompt = df.iloc[idx, 0]
|
||||||
k = df.shape[1] - 2
|
k = df.shape[1] - 2
|
||||||
for j in range(k):
|
for j in range(k):
|
||||||
prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j+1])
|
prompt += "\n{}. {}".format(choices[j], df.iloc[idx, j + 1])
|
||||||
prompt += "\nAnswer:"
|
prompt += "\nAnswer:"
|
||||||
if include_answer:
|
if include_answer:
|
||||||
prompt += " {}\n\n".format(df.iloc[idx, k + 1])
|
prompt += " {}\n\n".format(df.iloc[idx, k + 1])
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def gen_prompt(train_df, subject, k=-1):
|
def gen_prompt(train_df, subject, k=-1):
|
||||||
prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(format_subject(subject))
|
prompt = "The following are multiple choice questions (with answers) about{}.\n\n".format(
|
||||||
|
format_subject(subject)
|
||||||
|
)
|
||||||
if k == -1:
|
if k == -1:
|
||||||
k = train_df.shape[0]
|
k = train_df.shape[0]
|
||||||
for i in range(k):
|
for i in range(k):
|
||||||
prompt += format_example(train_df, i)
|
prompt += format_example(train_df, i)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def evaluate(args, subject, dev_df, test_df):
|
def evaluate(args, subject, dev_df, test_df):
|
||||||
prompts = []
|
prompts = []
|
||||||
labels = []
|
labels = []
|
||||||
@@ -54,7 +62,7 @@ def evaluate(args, subject, dev_df, test_df):
|
|||||||
prompt_end = format_example(test_df, i, include_answer=False)
|
prompt_end = format_example(test_df, i, include_answer=False)
|
||||||
prompts.append(prompt_end)
|
prompts.append(prompt_end)
|
||||||
|
|
||||||
label = test_df.iloc[i, test_df.shape[1]-1]
|
label = test_df.iloc[i, test_df.shape[1] - 1]
|
||||||
labels.append(label)
|
labels.append(label)
|
||||||
|
|
||||||
arguments = [{"question": p} for p in prompts]
|
arguments = [{"question": p} for p in prompts]
|
||||||
@@ -66,11 +74,14 @@ def evaluate(args, subject, dev_df, test_df):
|
|||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
|
|
||||||
if args.backend.startswith("gpt-"):
|
if args.backend.startswith("gpt-"):
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
def few_shot_mmlu(s, examples, question):
|
def few_shot_mmlu(s, examples, question):
|
||||||
s += sgl.user(examples + question)
|
s += sgl.user(examples + question)
|
||||||
s += sgl.assistant(sgl.gen("answer"))
|
s += sgl.assistant(sgl.gen("answer"))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
def few_shot_mmlu(s, examples, question):
|
def few_shot_mmlu(s, examples, question):
|
||||||
s += examples + question + sgl.gen("answer")
|
s += examples + question + sgl.gen("answer")
|
||||||
@@ -84,32 +95,50 @@ def evaluate(args, subject, dev_df, test_df):
|
|||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
states = few_shot_mmlu.bind(examples=few_shot_examples).run_batch(
|
states = few_shot_mmlu.bind(examples=few_shot_examples).run_batch(
|
||||||
arguments, temperature=0, max_new_tokens=1,
|
arguments,
|
||||||
backend=backend, num_threads=args.parallel)
|
temperature=0,
|
||||||
preds = [s["answer"].strip()[0] if len(s["answer"].strip()) > 0 else ""
|
max_new_tokens=1,
|
||||||
for s in states]
|
backend=backend,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
)
|
||||||
|
preds = [
|
||||||
|
s["answer"].strip()[0] if len(s["answer"].strip()) > 0 else "" for s in states
|
||||||
|
]
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
cors = [pred == label for pred, label in zip(preds, labels)]
|
cors = [pred == label for pred, label in zip(preds, labels)]
|
||||||
acc = np.mean(cors)
|
acc = np.mean(cors)
|
||||||
cors = np.array(cors)
|
cors = np.array(cors)
|
||||||
|
|
||||||
print("Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
|
print(
|
||||||
acc, latency, len(prompts), subject))
|
"Average accuracy {:.3f}, latency {:.2f}, #q: {} - {}".format(
|
||||||
|
acc, latency, len(prompts), subject
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return cors, acc, latency
|
return cors, acc, latency
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
subjects = sorted([f.split("_test.csv")[0] for f in os.listdir(os.path.join(args.data_dir, "test")) if "_test.csv" in f])
|
subjects = sorted(
|
||||||
|
[
|
||||||
|
f.split("_test.csv")[0]
|
||||||
|
for f in os.listdir(os.path.join(args.data_dir, "test"))
|
||||||
|
if "_test.csv" in f
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
all_cors = []
|
all_cors = []
|
||||||
all_latencies = []
|
all_latencies = []
|
||||||
num_requests = 0
|
num_requests = 0
|
||||||
|
|
||||||
for subject in tqdm(subjects[:args.nsub]):
|
for subject in tqdm(subjects[: args.nsub]):
|
||||||
dev_df = pd.read_csv(os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None)[:args.ntrain]
|
dev_df = pd.read_csv(
|
||||||
test_df = pd.read_csv(os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None)
|
os.path.join(args.data_dir, "dev", subject + "_dev.csv"), header=None
|
||||||
|
)[: args.ntrain]
|
||||||
|
test_df = pd.read_csv(
|
||||||
|
os.path.join(args.data_dir, "test", subject + "_test.csv"), header=None
|
||||||
|
)
|
||||||
|
|
||||||
cors, acc, latency = evaluate(args, subject, dev_df, test_df)
|
cors, acc, latency = evaluate(args, subject, dev_df, test_df)
|
||||||
all_cors.append(cors)
|
all_cors.append(cors)
|
||||||
@@ -134,7 +163,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"nsub": args.nsub,
|
"nsub": args.nsub,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,19 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from functools import partial
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from fastchat.model import get_conversation_template
|
from fastchat.model import get_conversation_template
|
||||||
import requests
|
|
||||||
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt
|
from sglang.test.test_utils import (
|
||||||
|
add_common_other_args_and_parse,
|
||||||
|
call_generate_lightllm,
|
||||||
|
call_generate_srt,
|
||||||
|
call_generate_vllm,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_questions(filename):
|
def load_questions(filename):
|
||||||
@@ -38,7 +43,7 @@ def write_answers(filename, model_id, questions, answers):
|
|||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
questions = load_questions(args.question_file)
|
questions = load_questions(args.question_file)
|
||||||
questions = (questions * 10)[:args.num_questions]
|
questions = (questions * 10)[: args.num_questions]
|
||||||
max_tokens = 256
|
max_tokens = 256
|
||||||
model_id = "llama-2-chat"
|
model_id = "llama-2-chat"
|
||||||
|
|
||||||
@@ -68,8 +73,7 @@ def main(args):
|
|||||||
conv.append_message(conv.roles[1], None)
|
conv.append_message(conv.roles[1], None)
|
||||||
|
|
||||||
prompt = conv.get_prompt()
|
prompt = conv.get_prompt()
|
||||||
output = call_generate(prompt,
|
output = call_generate(prompt, temperature=0, max_tokens=max_tokens).strip()
|
||||||
temperature=0, max_tokens=max_tokens).strip()
|
|
||||||
|
|
||||||
cur_answers.append(output)
|
cur_answers.append(output)
|
||||||
conv.update_last_message(output)
|
conv.update_last_message(output)
|
||||||
@@ -102,7 +106,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,10 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_questions(filename):
|
def load_questions(filename):
|
||||||
@@ -44,10 +47,9 @@ def answer_mt_bench(s, question_1, question_2):
|
|||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
# Construct prompts
|
# Construct prompts
|
||||||
questions = load_questions(args.question_file)[:args.num_questions]
|
questions = load_questions(args.question_file)[: args.num_questions]
|
||||||
arguments = [
|
arguments = [
|
||||||
{"question_1": q["turns"][0], "question_2": q["turns"][1]}
|
{"question_1": q["turns"][0], "question_2": q["turns"][1]} for q in questions
|
||||||
for q in questions
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# Select backend
|
# Select backend
|
||||||
@@ -83,7 +85,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -1,23 +1,28 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
import asyncio
|
import asyncio
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from functools import partial
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
|
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_other_args_and_parse,
|
||||||
|
call_generate_lightllm,
|
||||||
|
call_generate_srt_raw,
|
||||||
|
call_generate_vllm,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
INVALID = -9999999
|
INVALID = -9999999
|
||||||
|
|
||||||
|
|
||||||
def get_answer_value(answer_str):
|
def get_answer_value(answer_str):
|
||||||
answer_str = answer_str.replace(",", "")
|
answer_str = answer_str.replace(",", "")
|
||||||
numbers = re.findall(r'\d+', answer_str)
|
numbers = re.findall(r"\d+", answer_str)
|
||||||
if len(numbers) < 1:
|
if len(numbers) < 1:
|
||||||
return INVALID
|
return INVALID
|
||||||
try:
|
try:
|
||||||
@@ -44,14 +49,20 @@ def multi_chain_gsm8k(question, num_chains, call_generate):
|
|||||||
|
|
||||||
comps = []
|
comps = []
|
||||||
for i in range(num_chains):
|
for i in range(num_chains):
|
||||||
comps.append(call_generate(s + "Answer: " + prompt_lib[i % num_chains],
|
comps.append(
|
||||||
max_tokens=256, temperature=0.3, stop="Question"))
|
call_generate(
|
||||||
|
s + "Answer: " + prompt_lib[i % num_chains],
|
||||||
|
max_tokens=256,
|
||||||
|
temperature=0.3,
|
||||||
|
stop="Question",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
s += "Answer: To answer this question, here are some possible solutions. "
|
s += "Answer: To answer this question, here are some possible solutions. "
|
||||||
s += "After considering all of them, I will do a majority vote.\n\n"
|
s += "After considering all of them, I will do a majority vote.\n\n"
|
||||||
for i in range(num_chains):
|
for i in range(num_chains):
|
||||||
s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
|
s += f"Solution {i+1}: " + comps[i].strip() + "\n\n"
|
||||||
s += f"\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
|
s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
|
||||||
s += call_generate(s, max_tokens=16, temperature=0, stop=None)
|
s += call_generate(s, max_tokens=16, temperature=0, stop=None)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
@@ -64,7 +75,7 @@ def main(args):
|
|||||||
|
|
||||||
questions = []
|
questions = []
|
||||||
labels = []
|
labels = []
|
||||||
for i in range(len(lines[:args.num_questions])):
|
for i in range(len(lines[: args.num_questions])):
|
||||||
questions.append(lines[i]["question"])
|
questions.append(lines[i]["question"])
|
||||||
labels.append(get_answer_value(lines[i]["answer"]))
|
labels.append(get_answer_value(lines[i]["answer"]))
|
||||||
assert all(l != INVALID for l in labels)
|
assert all(l != INVALID for l in labels)
|
||||||
@@ -82,16 +93,28 @@ def main(args):
|
|||||||
url = f"{args.host}:{args.port}/generate"
|
url = f"{args.host}:{args.port}/generate"
|
||||||
call_generate = partial(call_generate_srt_raw, url=url)
|
call_generate = partial(call_generate_srt_raw, url=url)
|
||||||
elif args.backend == "guidance":
|
elif args.backend == "guidance":
|
||||||
from guidance import models, gen
|
from guidance import gen, models
|
||||||
|
|
||||||
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
|
model = models.LlamaCpp(
|
||||||
|
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
|
||||||
|
n_gpu_layers=-1,
|
||||||
|
n_ctx=4096,
|
||||||
|
)
|
||||||
|
|
||||||
def call_generate(prompt, temperature, max_tokens, stop):
|
def call_generate(prompt, temperature, max_tokens, stop):
|
||||||
out = model + prompt + gen(name="answer",
|
out = (
|
||||||
max_tokens=max_tokens, temperature=temperature, stop=stop)
|
model
|
||||||
|
+ prompt
|
||||||
|
+ gen(
|
||||||
|
name="answer",
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
|
)
|
||||||
return out["answer"]
|
return out["answer"]
|
||||||
|
|
||||||
#def multi_chain_gsm8k(question, num_chains, call_generate):
|
# def multi_chain_gsm8k(question, num_chains, call_generate):
|
||||||
# s = model + "Question: " + question + "\n"
|
# s = model + "Question: " + question + "\n"
|
||||||
|
|
||||||
# comps = []
|
# comps = []
|
||||||
@@ -108,8 +131,10 @@ def main(args):
|
|||||||
|
|
||||||
elif args.backend == "lmql":
|
elif args.backend == "lmql":
|
||||||
import lmql
|
import lmql
|
||||||
model = lmql.model("meta-llama/Llama-2-7b-chat-hf",
|
|
||||||
endpoint=f"{args.host}:{args.port}")
|
model = lmql.model(
|
||||||
|
"meta-llama/Llama-2-7b-chat-hf", endpoint=f"{args.host}:{args.port}"
|
||||||
|
)
|
||||||
|
|
||||||
@lmql.query(model=model)
|
@lmql.query(model=model)
|
||||||
async def program(question):
|
async def program(question):
|
||||||
@@ -128,8 +153,7 @@ def main(args):
|
|||||||
if args.backend != "lmql":
|
if args.backend != "lmql":
|
||||||
# Use thread pool
|
# Use thread pool
|
||||||
def get_one_answer(i):
|
def get_one_answer(i):
|
||||||
answer = multi_chain_gsm8k(questions[i], args.num_chains,
|
answer = multi_chain_gsm8k(questions[i], args.num_chains, call_generate)
|
||||||
call_generate)
|
|
||||||
states[i] = answer
|
states[i] = answer
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
@@ -144,12 +168,18 @@ def main(args):
|
|||||||
async def batched_call(batch_size):
|
async def batched_call(batch_size):
|
||||||
for i in range(0, len(questions), batch_size):
|
for i in range(0, len(questions), batch_size):
|
||||||
tasks = []
|
tasks = []
|
||||||
for q in questions[i:i+batch_size]:
|
for q in questions[i : i + batch_size]:
|
||||||
tasks.append(call_generate(few_shot_examples + q,
|
tasks.append(
|
||||||
temperature=0, max_tokens=256, stop="Question"))
|
call_generate(
|
||||||
|
few_shot_examples + q,
|
||||||
|
temperature=0,
|
||||||
|
max_tokens=256,
|
||||||
|
stop="Question",
|
||||||
|
)
|
||||||
|
)
|
||||||
rets = await asyncio.gather(*tasks)
|
rets = await asyncio.gather(*tasks)
|
||||||
for j in range(len(rets)):
|
for j in range(len(rets)):
|
||||||
states[i+j] = get_answer_value(rets[j])
|
states[i + j] = get_answer_value(rets[j])
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
asyncio.run(batched_call(batch_size=args.parallel))
|
asyncio.run(batched_call(batch_size=args.parallel))
|
||||||
@@ -180,7 +210,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -5,16 +5,19 @@ import re
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
INVALID = -9999999
|
INVALID = -9999999
|
||||||
|
|
||||||
|
|
||||||
def get_answer_value(answer_str):
|
def get_answer_value(answer_str):
|
||||||
answer_str = answer_str.replace(",", "")
|
answer_str = answer_str.replace(",", "")
|
||||||
numbers = re.findall(r'\d+', answer_str)
|
numbers = re.findall(r"\d+", answer_str)
|
||||||
if len(numbers) < 1:
|
if len(numbers) < 1:
|
||||||
return INVALID
|
return INVALID
|
||||||
try:
|
try:
|
||||||
@@ -37,12 +40,12 @@ def main(args):
|
|||||||
lines = read_jsonl(args.data_path)
|
lines = read_jsonl(args.data_path)
|
||||||
|
|
||||||
# Construct prompts
|
# Construct prompts
|
||||||
#k = args.num_shot
|
# k = args.num_shot
|
||||||
#few_shot_examples = get_few_shot_examples(lines, k)
|
# few_shot_examples = get_few_shot_examples(lines, k)
|
||||||
|
|
||||||
questions = []
|
questions = []
|
||||||
labels = []
|
labels = []
|
||||||
for i in range(len(lines[:args.num_questions])):
|
for i in range(len(lines[: args.num_questions])):
|
||||||
questions.append(lines[i]["question"])
|
questions.append(lines[i]["question"])
|
||||||
labels.append(get_answer_value(lines[i]["answer"]))
|
labels.append(get_answer_value(lines[i]["answer"]))
|
||||||
assert all(l != INVALID for l in labels)
|
assert all(l != INVALID for l in labels)
|
||||||
@@ -59,21 +62,24 @@ def main(args):
|
|||||||
@sgl.function
|
@sgl.function
|
||||||
def multi_chain_gsm8k(s, question):
|
def multi_chain_gsm8k(s, question):
|
||||||
s += "Question: " + question + "\n"
|
s += "Question: " + question + "\n"
|
||||||
#s += "Answer: " + prompt_lib[0] + sgl.gen("answer", max_tokens=256, stop="Question",
|
# s += "Answer: " + prompt_lib[0] + sgl.gen("answer", max_tokens=256, stop="Question",
|
||||||
# temperature=0)
|
# temperature=0)
|
||||||
#return
|
# return
|
||||||
|
|
||||||
forks = s.fork(num_chains)
|
forks = s.fork(num_chains)
|
||||||
for i in range(num_chains):
|
for i in range(num_chains):
|
||||||
forks[i] += ("Answer: " + prompt_lib[i % num_chains] +
|
forks[i] += (
|
||||||
sgl.gen(f"chain", max_tokens=256, temperature=0.3, stop="Question"))
|
"Answer: "
|
||||||
|
+ prompt_lib[i % num_chains]
|
||||||
|
+ sgl.gen("chain", max_tokens=256, temperature=0.3, stop="Question")
|
||||||
|
)
|
||||||
forks.join()
|
forks.join()
|
||||||
|
|
||||||
s += "Answer: To answer this question, here are some possible solutions. "
|
s += "Answer: To answer this question, here are some possible solutions. "
|
||||||
s += "After considering all of them, I will do a majority vote.\n\n"
|
s += "After considering all of them, I will do a majority vote.\n\n"
|
||||||
for i in range(num_chains):
|
for i in range(num_chains):
|
||||||
s += f"Solution {i+1}: " + forks[i]["chain"].strip() + "\n\n"
|
s += f"Solution {i+1}: " + forks[i]["chain"].strip() + "\n\n"
|
||||||
s += f"\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
|
s += "\nBy considering the above solutions and doing a majority vote, I think the final answer (a single integer number) is "
|
||||||
s += sgl.gen("answer", max_tokens=16)
|
s += sgl.gen("answer", max_tokens=16)
|
||||||
|
|
||||||
#####################################
|
#####################################
|
||||||
@@ -86,7 +92,12 @@ def main(args):
|
|||||||
# Run requests
|
# Run requests
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
states = multi_chain_gsm8k.run_batch(
|
states = multi_chain_gsm8k.run_batch(
|
||||||
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True)
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
backend=backend,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
preds = []
|
preds = []
|
||||||
@@ -114,7 +125,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,18 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from functools import partial
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
|
||||||
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
|
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_other_args_and_parse,
|
||||||
|
call_generate_lightllm,
|
||||||
|
call_generate_srt_raw,
|
||||||
|
call_generate_vllm,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
USER_PREFIX = "[INST] "
|
USER_PREFIX = "[INST] "
|
||||||
USER_SUFFIX = " [/INST]"
|
USER_SUFFIX = " [/INST]"
|
||||||
@@ -25,7 +28,11 @@ def multi_document_qa(docs, question, generate):
|
|||||||
s += "".join(docs)
|
s += "".join(docs)
|
||||||
|
|
||||||
s += "\nDocuments end."
|
s += "\nDocuments end."
|
||||||
s += ("\n\nBased on the above documents, please answer this question:\n" + question + "\nAnswer in three words or fewer.")
|
s += (
|
||||||
|
"\n\nBased on the above documents, please answer this question:\n"
|
||||||
|
+ question
|
||||||
|
+ "\nAnswer in three words or fewer."
|
||||||
|
)
|
||||||
s += USER_SUFFIX
|
s += USER_SUFFIX
|
||||||
s += ASSISTANT_PREFIX
|
s += ASSISTANT_PREFIX
|
||||||
answer = generate(s, max_tokens=16, stop=None)
|
answer = generate(s, max_tokens=16, stop=None)
|
||||||
@@ -42,11 +49,13 @@ def main(args):
|
|||||||
if args.backend == "guidance":
|
if args.backend == "guidance":
|
||||||
num_docs = 7 # due to OOM
|
num_docs = 7 # due to OOM
|
||||||
|
|
||||||
for i in range(len(l["questions"][:args.num_questions])):
|
for i in range(len(l["questions"][: args.num_questions])):
|
||||||
arguments.append({
|
arguments.append(
|
||||||
"docs": l["documents"][:num_docs],
|
{
|
||||||
"question": l["questions"][i],
|
"docs": l["documents"][:num_docs],
|
||||||
})
|
"question": l["questions"][i],
|
||||||
|
}
|
||||||
|
)
|
||||||
labels.append(l["answers"][i])
|
labels.append(l["answers"][i])
|
||||||
states = [None] * len(arguments)
|
states = [None] * len(arguments)
|
||||||
|
|
||||||
@@ -61,13 +70,20 @@ def main(args):
|
|||||||
url = f"{args.host}:{args.port}/generate"
|
url = f"{args.host}:{args.port}/generate"
|
||||||
generate = partial(call_generate_srt_raw, url=url, temperature=0)
|
generate = partial(call_generate_srt_raw, url=url, temperature=0)
|
||||||
elif args.backend == "guidance":
|
elif args.backend == "guidance":
|
||||||
from guidance import models, gen
|
from guidance import gen, models
|
||||||
|
|
||||||
model = models.LlamaCpp("/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf", n_gpu_layers=-1, n_ctx=11000)
|
model = models.LlamaCpp(
|
||||||
|
"/home/ubuntu/model_weights/CodeLlama-7b-instruct-hf.gguf",
|
||||||
|
n_gpu_layers=-1,
|
||||||
|
n_ctx=11000,
|
||||||
|
)
|
||||||
|
|
||||||
def generate(prompt, max_tokens, stop):
|
def generate(prompt, max_tokens, stop):
|
||||||
out = model + prompt + gen(name="answer",
|
out = (
|
||||||
max_tokens=max_tokens, temperature=0, stop=stop)
|
model
|
||||||
|
+ prompt
|
||||||
|
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
|
||||||
|
)
|
||||||
return out["answer"]
|
return out["answer"]
|
||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
@@ -113,7 +129,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,12 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
from sglang.test.test_utils import (
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
@@ -19,7 +21,11 @@ def multi_document_qa(s, docs, question):
|
|||||||
forks.join("concate_and_append")
|
forks.join("concate_and_append")
|
||||||
|
|
||||||
s += "\nDocuments end."
|
s += "\nDocuments end."
|
||||||
s += ("\n\nBased on the above documents, please answer this question:\n" + question + "\nAnswer in three words or fewer.")
|
s += (
|
||||||
|
"\n\nBased on the above documents, please answer this question:\n"
|
||||||
|
+ question
|
||||||
|
+ "\nAnswer in three words or fewer."
|
||||||
|
)
|
||||||
s += sgl.user_end()
|
s += sgl.user_end()
|
||||||
s += sgl.assistant(sgl.gen("answer", max_tokens=16))
|
s += sgl.assistant(sgl.gen("answer", max_tokens=16))
|
||||||
|
|
||||||
@@ -29,11 +35,13 @@ def main(args):
|
|||||||
l = lines[0]
|
l = lines[0]
|
||||||
arguments = []
|
arguments = []
|
||||||
labels = []
|
labels = []
|
||||||
for i in range(len(l["questions"][:args.num_questions])):
|
for i in range(len(l["questions"][: args.num_questions])):
|
||||||
arguments.append({
|
arguments.append(
|
||||||
"docs": l["documents"][:10],
|
{
|
||||||
"question": l["questions"][i],
|
"docs": l["documents"][:10],
|
||||||
})
|
"question": l["questions"][i],
|
||||||
|
}
|
||||||
|
)
|
||||||
labels.append(l["answers"][i])
|
labels.append(l["answers"][i])
|
||||||
|
|
||||||
# Select backend
|
# Select backend
|
||||||
@@ -43,7 +51,8 @@ def main(args):
|
|||||||
# Run requests
|
# Run requests
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
states = multi_document_qa.run_batch(
|
states = multi_document_qa.run_batch(
|
||||||
arguments, temperature=0, num_threads=args.parallel, progress_bar=True)
|
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
|
||||||
|
)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
# Compute accuracy
|
# Compute accuracy
|
||||||
@@ -71,7 +80,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ import json
|
|||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
content = "\n".join(
|
content = "\n".join(
|
||||||
open("llama2.txt", 'r', encoding='utf-8', errors='ignore').readlines())
|
open("llama2.txt", "r", encoding="utf-8", errors="ignore").readlines()
|
||||||
|
)
|
||||||
content = content.replace("\n\n", "\n")
|
content = content.replace("\n\n", "\n")
|
||||||
|
|
||||||
# Count token
|
# Count token
|
||||||
@@ -35,30 +36,35 @@ for i, s in enumerate(segments):
|
|||||||
|
|
||||||
# Dump
|
# Dump
|
||||||
with open("questions.jsonl", "w") as fout:
|
with open("questions.jsonl", "w") as fout:
|
||||||
fout.write(json.dumps({
|
fout.write(
|
||||||
"documents": segments[:30],
|
json.dumps(
|
||||||
"questions": [
|
{
|
||||||
"What is the name of the fine-tuned LLMs?",
|
"documents": segments[:30],
|
||||||
"Which figure shows the helpfulness human evaluation results for Llama 2-Chat?",
|
"questions": [
|
||||||
"What is the number of parameters in the largest Llama 2 model?",
|
"What is the name of the fine-tuned LLMs?",
|
||||||
"What is the batch size of fine-tuning?",
|
"Which figure shows the helpfulness human evaluation results for Llama 2-Chat?",
|
||||||
"Where can we find the details of potential data contamination?",
|
"What is the number of parameters in the largest Llama 2 model?",
|
||||||
"What is the full name of MPT?",
|
"What is the batch size of fine-tuning?",
|
||||||
"What is the power consumption of RSC in Watt?",
|
"Where can we find the details of potential data contamination?",
|
||||||
"How many tokens of data do they train on?",
|
"What is the full name of MPT?",
|
||||||
"Which model's release is delayed due to a lack of time to sufficiently red team?",
|
"What is the power consumption of RSC in Watt?",
|
||||||
"Which activation function is used in Llama?"
|
"How many tokens of data do they train on?",
|
||||||
],
|
"Which model's release is delayed due to a lack of time to sufficiently red team?",
|
||||||
"answers": [
|
"Which activation function is used in Llama?",
|
||||||
"Llama 2 Chat",
|
],
|
||||||
"1",
|
"answers": [
|
||||||
"70 B",
|
"Llama 2 Chat",
|
||||||
"64",
|
"1",
|
||||||
"A 6",
|
"70 B",
|
||||||
"MosaicML",
|
"64",
|
||||||
"400",
|
"A 6",
|
||||||
"2 trillion",
|
"MosaicML",
|
||||||
"34 B",
|
"400",
|
||||||
"SwiGLU",
|
"2 trillion",
|
||||||
],
|
"34 B",
|
||||||
}) + "\n")
|
"SwiGLU",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ from argparse import ArgumentParser
|
|||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from sglang.test.test_utils import add_common_other_args_and_parse
|
from data_gen import gen_arguments
|
||||||
from sglang.utils import dump_state_text
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
from data_gen import gen_arguments
|
from sglang.test.test_utils import add_common_other_args_and_parse
|
||||||
|
from sglang.utils import dump_state_text
|
||||||
|
|
||||||
|
|
||||||
def get_generate(args):
|
def get_generate(args):
|
||||||
|
|||||||
@@ -2,15 +2,15 @@ import json
|
|||||||
import time
|
import time
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
|
from data_gen import gen_arguments
|
||||||
|
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||||
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
add_common_sglang_args_and_parse,
|
add_common_sglang_args_and_parse,
|
||||||
select_sglang_backend,
|
select_sglang_backend,
|
||||||
)
|
)
|
||||||
from sglang.utils import dump_state_text
|
from sglang.utils import dump_state_text
|
||||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
|
||||||
|
|
||||||
from data_gen import gen_arguments
|
|
||||||
|
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
@@ -29,7 +29,11 @@ def main(args):
|
|||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
states = multi_turns.run_batch(
|
states = multi_turns.run_batch(
|
||||||
multi_qas, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True
|
multi_qas,
|
||||||
|
temperature=0,
|
||||||
|
backend=backend,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
)
|
)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,19 @@
|
|||||||
import argparse
|
import argparse
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from functools import partial
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
add_common_other_args_and_parse,
|
add_common_other_args_and_parse,
|
||||||
call_generate_lightllm,
|
call_generate_lightllm,
|
||||||
call_generate_vllm,
|
|
||||||
call_generate_srt_raw,
|
call_generate_srt_raw,
|
||||||
|
call_generate_vllm,
|
||||||
)
|
)
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
|
||||||
def get_prompt(question):
|
def get_prompt(question):
|
||||||
@@ -83,16 +84,15 @@ Action 2: Search[Leonid Levin]
|
|||||||
Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
|
Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
|
||||||
Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
|
Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
|
||||||
Action 3: Finish[yes]
|
Action 3: Finish[yes]
|
||||||
""" + question)
|
"""
|
||||||
|
+ question
|
||||||
|
)
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
lines = read_jsonl(args.data_path)[:args.num_questions]
|
lines = read_jsonl(args.data_path)[: args.num_questions]
|
||||||
arguments = [{
|
arguments = [{"question": k, "triplets": v} for l in lines for k, v in l.items()]
|
||||||
"question": k,
|
|
||||||
"triplets": v
|
|
||||||
} for l in lines for k, v in l.items()]
|
|
||||||
|
|
||||||
states = []
|
states = []
|
||||||
|
|
||||||
@@ -107,7 +107,7 @@ def main(args):
|
|||||||
url = f"{args.host}:{args.port}/generate"
|
url = f"{args.host}:{args.port}/generate"
|
||||||
call_generate = partial(call_generate_srt_raw, url=url)
|
call_generate = partial(call_generate_srt_raw, url=url)
|
||||||
elif args.backend == "guidance":
|
elif args.backend == "guidance":
|
||||||
from guidance import models, gen
|
from guidance import gen, models
|
||||||
|
|
||||||
model = models.LlamaCpp(
|
model = models.LlamaCpp(
|
||||||
str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf",
|
str(Path.home()) + "/model_weights/Llama-2-7b-chat.gguf",
|
||||||
@@ -116,12 +116,16 @@ def main(args):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def call_generate(prompt, temperature, max_tokens, stop):
|
def call_generate(prompt, temperature, max_tokens, stop):
|
||||||
out = (model + prompt + gen(
|
out = (
|
||||||
name="result",
|
model
|
||||||
max_tokens=max_tokens,
|
+ prompt
|
||||||
temperature=temperature,
|
+ gen(
|
||||||
stop=stop,
|
name="result",
|
||||||
))
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
|
)
|
||||||
return out["result"]
|
return out["result"]
|
||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
@@ -137,15 +141,23 @@ def main(args):
|
|||||||
for i in range(1, len(triplets) + 2):
|
for i in range(1, len(triplets) + 2):
|
||||||
prompt += "Thought " + str(i) + ":"
|
prompt += "Thought " + str(i) + ":"
|
||||||
states.append(prompt)
|
states.append(prompt)
|
||||||
answer = call_generate(prompt,
|
answer = call_generate(
|
||||||
max_tokens=200,
|
prompt, max_tokens=200, temperature=0, stop="Observation"
|
||||||
temperature=0,
|
)
|
||||||
stop="Observation")
|
|
||||||
if i > len(triplets):
|
if i > len(triplets):
|
||||||
break
|
break
|
||||||
prompt += (triplets[i - 1]["thought"] + "\nAction " + str(i) +
|
prompt += (
|
||||||
":" + triplets[i - 1]["action"] + "\nObservation " +
|
triplets[i - 1]["thought"]
|
||||||
str(i) + ":" + triplets[i - 1]["observation"] + "\n")
|
+ "\nAction "
|
||||||
|
+ str(i)
|
||||||
|
+ ":"
|
||||||
|
+ triplets[i - 1]["action"]
|
||||||
|
+ "\nObservation "
|
||||||
|
+ str(i)
|
||||||
|
+ ":"
|
||||||
|
+ triplets[i - 1]["observation"]
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
states.append(answer)
|
states.append(answer)
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from sglang.test.test_utils import (
|
|||||||
add_common_sglang_args_and_parse,
|
add_common_sglang_args_and_parse,
|
||||||
select_sglang_backend,
|
select_sglang_backend,
|
||||||
)
|
)
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
@@ -79,7 +79,9 @@ Action 2: Search[Leonid Levin]
|
|||||||
Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
|
Observation 2: Leonid Anatolievich Levin is a Soviet-American mathematician and computer scientist.
|
||||||
Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
|
Thought 3: Leonid Levin is a mathematician and computer scientist. So Pavel Urysohn and Leonid Levin have the same type of work.
|
||||||
Action 3: Finish[yes]
|
Action 3: Finish[yes]
|
||||||
""" + question)
|
"""
|
||||||
|
+ question
|
||||||
|
)
|
||||||
for i in range(1, len(triplets) + 2):
|
for i in range(1, len(triplets) + 2):
|
||||||
s += "Thought " + str(i) + ":"
|
s += "Thought " + str(i) + ":"
|
||||||
# NOTE: This is an implementation for replaying a given trace for benchmark purposes. It is not an actual ReAct agent implementation.
|
# NOTE: This is an implementation for replaying a given trace for benchmark purposes. It is not an actual ReAct agent implementation.
|
||||||
@@ -90,17 +92,23 @@ Action 3: Finish[yes]
|
|||||||
# print(ss[0]["thought_action"])
|
# print(ss[0]["thought_action"])
|
||||||
if i > len(triplets):
|
if i > len(triplets):
|
||||||
break
|
break
|
||||||
s += (triplets[i - 1]["thought"] + "\nAction " + str(i) + ":" +
|
s += (
|
||||||
triplets[i - 1]["action"] + "\nObservation " + str(i) + ":" +
|
triplets[i - 1]["thought"]
|
||||||
triplets[i - 1]["observation"] + "\n")
|
+ "\nAction "
|
||||||
|
+ str(i)
|
||||||
|
+ ":"
|
||||||
|
+ triplets[i - 1]["action"]
|
||||||
|
+ "\nObservation "
|
||||||
|
+ str(i)
|
||||||
|
+ ":"
|
||||||
|
+ triplets[i - 1]["observation"]
|
||||||
|
+ "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
lines = read_jsonl(args.data_path)[:args.num_questions]
|
lines = read_jsonl(args.data_path)[: args.num_questions]
|
||||||
arguments = [{
|
arguments = [{"question": k, "triplets": v} for l in lines for k, v in l.items()]
|
||||||
"question": k,
|
|
||||||
"triplets": v
|
|
||||||
} for l in lines for k, v in l.items()]
|
|
||||||
|
|
||||||
# Select backend
|
# Select backend
|
||||||
backend = select_sglang_backend(args)
|
backend = select_sglang_backend(args)
|
||||||
@@ -108,11 +116,12 @@ def main(args):
|
|||||||
|
|
||||||
states = []
|
states = []
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
states = webthink.run_batch(arguments,
|
states = webthink.run_batch(
|
||||||
temperature=0,
|
arguments,
|
||||||
num_threads=args.parallel,
|
temperature=0,
|
||||||
progress_bar=True,
|
num_threads=args.parallel,
|
||||||
)
|
progress_bar=True,
|
||||||
|
)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
# Compute accuracy
|
# Compute accuracy
|
||||||
|
|||||||
@@ -1,22 +1,25 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from functools import partial
|
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import numpy as np
|
|
||||||
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
|
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_other_args_and_parse,
|
||||||
|
call_generate_lightllm,
|
||||||
|
call_generate_srt_raw,
|
||||||
|
call_generate_vllm,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
number = 5
|
number = 5
|
||||||
|
|
||||||
|
|
||||||
def expand_tip(topic, tip, generate):
|
def expand_tip(topic, tip, generate):
|
||||||
s = (
|
s = (
|
||||||
"""Please expand a tip for a topic into a detailed paragraph.
|
"""Please expand a tip for a topic into a detailed paragraph.
|
||||||
|
|
||||||
Topic: staying healthy
|
Topic: staying healthy
|
||||||
Tip: Regular Exercise
|
Tip: Regular Exercise
|
||||||
@@ -30,14 +33,23 @@ Topic: writing a blog post
|
|||||||
Tip: structure your content effectively
|
Tip: structure your content effectively
|
||||||
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.
|
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.
|
||||||
|
|
||||||
Topic: """ + topic + "\nTip: " + tip + "\nParagraph:")
|
Topic: """
|
||||||
|
+ topic
|
||||||
|
+ "\nTip: "
|
||||||
|
+ tip
|
||||||
|
+ "\nParagraph:"
|
||||||
|
)
|
||||||
return generate(s, max_tokens=128, stop=["\n\n"])
|
return generate(s, max_tokens=128, stop=["\n\n"])
|
||||||
|
|
||||||
|
|
||||||
def suggest_tips(topic, generate):
|
def suggest_tips(topic, generate):
|
||||||
s = "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
|
s = "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
|
||||||
s += "USER: Give some tips for " + topic + ".\n"
|
s += "USER: Give some tips for " + topic + ".\n"
|
||||||
s += ("ASSISTANT: Okay. Here are " + str(number) + " concise tips, each under 8 words:\n")
|
s += (
|
||||||
|
"ASSISTANT: Okay. Here are "
|
||||||
|
+ str(number)
|
||||||
|
+ " concise tips, each under 8 words:\n"
|
||||||
|
)
|
||||||
|
|
||||||
tips = []
|
tips = []
|
||||||
for i in range(1, 1 + number):
|
for i in range(1, 1 + number):
|
||||||
@@ -49,12 +61,12 @@ def suggest_tips(topic, generate):
|
|||||||
paragraphs = [expand_tip(topic, tip, generate=generate) for tip in tips]
|
paragraphs = [expand_tip(topic, tip, generate=generate) for tip in tips]
|
||||||
|
|
||||||
for i in range(1, 1 + number):
|
for i in range(1, 1 + number):
|
||||||
s += f"Tip {i}:" + paragraphs[i-1] + "\n"
|
s += f"Tip {i}:" + paragraphs[i - 1] + "\n"
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
lines = read_jsonl(args.data_path)[:args.num_questions]
|
lines = read_jsonl(args.data_path)[: args.num_questions]
|
||||||
states = [None] * len(lines)
|
states = [None] * len(lines)
|
||||||
|
|
||||||
# Select backend
|
# Select backend
|
||||||
@@ -68,13 +80,20 @@ def main(args):
|
|||||||
url = f"{args.host}:{args.port}/generate"
|
url = f"{args.host}:{args.port}/generate"
|
||||||
generate = partial(call_generate_srt_raw, url=url, temperature=0)
|
generate = partial(call_generate_srt_raw, url=url, temperature=0)
|
||||||
elif args.backend == "guidance":
|
elif args.backend == "guidance":
|
||||||
from guidance import models, gen
|
from guidance import gen, models
|
||||||
|
|
||||||
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
|
model = models.LlamaCpp(
|
||||||
|
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
|
||||||
|
n_gpu_layers=-1,
|
||||||
|
n_ctx=4096,
|
||||||
|
)
|
||||||
|
|
||||||
def generate(prompt, max_tokens, stop):
|
def generate(prompt, max_tokens, stop):
|
||||||
out = model + prompt + gen(name="answer",
|
out = (
|
||||||
max_tokens=max_tokens, temperature=0, stop=stop)
|
model
|
||||||
|
+ prompt
|
||||||
|
+ gen(name="answer", max_tokens=max_tokens, temperature=0, stop=stop)
|
||||||
|
)
|
||||||
return out["answer"]
|
return out["answer"]
|
||||||
|
|
||||||
# warmup
|
# warmup
|
||||||
@@ -111,7 +130,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -2,11 +2,12 @@ import argparse
|
|||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import sglang as sgl
|
import sglang as sgl
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
from sglang.test.test_utils import (
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
number = 5
|
number = 5
|
||||||
|
|
||||||
@@ -14,7 +15,7 @@ number = 5
|
|||||||
@sgl.function
|
@sgl.function
|
||||||
def expand_tip(s, topic, tip):
|
def expand_tip(s, topic, tip):
|
||||||
s += (
|
s += (
|
||||||
"""Please expand a tip for a topic into a detailed paragraph.
|
"""Please expand a tip for a topic into a detailed paragraph.
|
||||||
|
|
||||||
Topic: staying healthy
|
Topic: staying healthy
|
||||||
Tip: Regular Exercise
|
Tip: Regular Exercise
|
||||||
@@ -28,7 +29,12 @@ Topic: writing a blog post
|
|||||||
Tip: structure your content effectively
|
Tip: structure your content effectively
|
||||||
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.
|
Paragraph: A well-structured post is easier to read and more enjoyable. Start with an engaging introduction that hooks the reader and clearly states the purpose of your post. Use headings and subheadings to break up the text and guide readers through your content. Bullet points and numbered lists can make information more digestible. Ensure each paragraph flows logically into the next, and conclude with a summary or call-to-action that encourages reader engagement.
|
||||||
|
|
||||||
Topic: """ + topic + "\nTip: " + tip + "\nParagraph:")
|
Topic: """
|
||||||
|
+ topic
|
||||||
|
+ "\nTip: "
|
||||||
|
+ tip
|
||||||
|
+ "\nParagraph:"
|
||||||
|
)
|
||||||
s += sgl.gen("paragraph", max_tokens=128, stop=["\n\n"], temperature=0)
|
s += sgl.gen("paragraph", max_tokens=128, stop=["\n\n"], temperature=0)
|
||||||
|
|
||||||
|
|
||||||
@@ -36,7 +42,11 @@ Topic: """ + topic + "\nTip: " + tip + "\nParagraph:")
|
|||||||
def suggest_tips(s, topic):
|
def suggest_tips(s, topic):
|
||||||
s += "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
|
s += "Please act as a helpful assistant. Your job is to provide users with useful tips on a specific topic.\n"
|
||||||
s += "USER: Give some tips for " + topic + ".\n"
|
s += "USER: Give some tips for " + topic + ".\n"
|
||||||
s += ("ASSISTANT: Okay. Here are " + str(number) + " concise tips, each under 8 words:\n")
|
s += (
|
||||||
|
"ASSISTANT: Okay. Here are "
|
||||||
|
+ str(number)
|
||||||
|
+ " concise tips, each under 8 words:\n"
|
||||||
|
)
|
||||||
|
|
||||||
paragraphs = []
|
paragraphs = []
|
||||||
for i in range(1, 1 + number):
|
for i in range(1, 1 + number):
|
||||||
@@ -44,14 +54,12 @@ def suggest_tips(s, topic):
|
|||||||
paragraphs.append(expand_tip(topic=topic, tip=s[f"tip_{i}"]))
|
paragraphs.append(expand_tip(topic=topic, tip=s[f"tip_{i}"]))
|
||||||
|
|
||||||
for i in range(1, 1 + number):
|
for i in range(1, 1 + number):
|
||||||
s += f"Tip {i}:" + paragraphs[i-1]["paragraph"] + "\n"
|
s += f"Tip {i}:" + paragraphs[i - 1]["paragraph"] + "\n"
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
lines = read_jsonl(args.data_path)[:args.num_questions]
|
lines = read_jsonl(args.data_path)[: args.num_questions]
|
||||||
arguments = [
|
arguments = [{"topic": l["topic"]} for l in lines]
|
||||||
{"topic": l["topic"]} for l in lines
|
|
||||||
]
|
|
||||||
|
|
||||||
# Select backend
|
# Select backend
|
||||||
sgl.set_default_backend(select_sglang_backend(args))
|
sgl.set_default_backend(select_sglang_backend(args))
|
||||||
@@ -59,7 +67,8 @@ def main(args):
|
|||||||
# Run requests
|
# Run requests
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
states = suggest_tips.run_batch(
|
states = suggest_tips.run_batch(
|
||||||
arguments, temperature=0, num_threads=args.parallel, progress_bar=True)
|
arguments, temperature=0, num_threads=args.parallel, progress_bar=True
|
||||||
|
)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
|
|
||||||
# Compute accuracy
|
# Compute accuracy
|
||||||
@@ -78,7 +87,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -1,25 +1,29 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
import asyncio
|
|
||||||
from collections import Counter
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from functools import partial
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from collections import Counter
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
|
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_other_args_and_parse,
|
||||||
|
call_generate_lightllm,
|
||||||
|
call_generate_srt_raw,
|
||||||
|
call_generate_vllm,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
INVALID = -9999999
|
INVALID = -9999999
|
||||||
|
|
||||||
|
|
||||||
def get_answer_value(answer_str):
|
def get_answer_value(answer_str):
|
||||||
answer_str = answer_str.replace(",", "")
|
answer_str = answer_str.replace(",", "")
|
||||||
numbers = re.findall(r'\d+', answer_str)
|
numbers = re.findall(r"\d+", answer_str)
|
||||||
if len(numbers) < 1:
|
if len(numbers) < 1:
|
||||||
return INVALID
|
return INVALID
|
||||||
try:
|
try:
|
||||||
@@ -47,35 +51,56 @@ temp = 0.001
|
|||||||
|
|
||||||
|
|
||||||
def propose_plan(s, question, num_branches, call_generate):
|
def propose_plan(s, question, num_branches, call_generate):
|
||||||
s += (USER_PREFIX +
|
s += (
|
||||||
"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question + USER_SUFFIX)
|
USER_PREFIX
|
||||||
|
+ """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
|
||||||
|
+ question
|
||||||
|
+ USER_SUFFIX
|
||||||
|
)
|
||||||
|
|
||||||
s += ASSISTANT_PREFIX
|
s += ASSISTANT_PREFIX
|
||||||
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
|
comps = call_generate(
|
||||||
|
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
|
||||||
|
)
|
||||||
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
||||||
|
|
||||||
|
|
||||||
def execute_plan(s, num_branches, call_generate):
|
def execute_plan(s, num_branches, call_generate):
|
||||||
s += (USER_PREFIX +
|
s += (
|
||||||
"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""" + USER_SUFFIX)
|
USER_PREFIX
|
||||||
|
+ """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
|
||||||
|
+ USER_SUFFIX
|
||||||
|
)
|
||||||
s += ASSISTANT_PREFIX
|
s += ASSISTANT_PREFIX
|
||||||
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
|
comps = call_generate(
|
||||||
|
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
|
||||||
|
)
|
||||||
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
||||||
|
|
||||||
|
|
||||||
def reflect_solution(s, num_branches, call_generate):
|
def reflect_solution(s, num_branches, call_generate):
|
||||||
s += (USER_PREFIX +
|
s += (
|
||||||
"""Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""" + USER_SUFFIX)
|
USER_PREFIX
|
||||||
|
+ """Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
|
||||||
|
+ USER_SUFFIX
|
||||||
|
)
|
||||||
s += ASSISTANT_PREFIX
|
s += ASSISTANT_PREFIX
|
||||||
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
|
comps = call_generate(
|
||||||
|
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
|
||||||
|
)
|
||||||
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
||||||
|
|
||||||
|
|
||||||
def get_final_answer(s, num_branches, call_generate):
|
def get_final_answer(s, num_branches, call_generate):
|
||||||
s += (USER_PREFIX +
|
s += (
|
||||||
"""Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.""" + USER_SUFFIX)
|
USER_PREFIX
|
||||||
|
+ """Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration."""
|
||||||
|
+ USER_SUFFIX
|
||||||
|
)
|
||||||
s += ASSISTANT_PREFIX
|
s += ASSISTANT_PREFIX
|
||||||
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
|
comps = call_generate(
|
||||||
|
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
|
||||||
|
)
|
||||||
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
||||||
|
|
||||||
|
|
||||||
@@ -107,7 +132,7 @@ def main(args):
|
|||||||
num_branches = 2
|
num_branches = 2
|
||||||
questions = []
|
questions = []
|
||||||
labels = []
|
labels = []
|
||||||
for i in range(len(lines[:args.num_questions])):
|
for i in range(len(lines[: args.num_questions])):
|
||||||
questions.append(lines[i]["question"])
|
questions.append(lines[i]["question"])
|
||||||
labels.append(get_answer_value(lines[i]["answer"]))
|
labels.append(get_answer_value(lines[i]["answer"]))
|
||||||
assert all(l != INVALID for l in labels)
|
assert all(l != INVALID for l in labels)
|
||||||
@@ -124,20 +149,40 @@ def main(args):
|
|||||||
url = f"{args.host}:{args.port}/generate"
|
url = f"{args.host}:{args.port}/generate"
|
||||||
call_generate = partial(call_generate_srt_raw, url=url)
|
call_generate = partial(call_generate_srt_raw, url=url)
|
||||||
elif args.backend == "guidance":
|
elif args.backend == "guidance":
|
||||||
from guidance import models, gen
|
from guidance import gen, models
|
||||||
|
|
||||||
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
|
model = models.LlamaCpp(
|
||||||
|
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
|
||||||
|
n_gpu_layers=-1,
|
||||||
|
n_ctx=4096,
|
||||||
|
)
|
||||||
|
|
||||||
def call_generate(prompt, temperature, max_tokens, stop, n):
|
def call_generate(prompt, temperature, max_tokens, stop, n):
|
||||||
if n == 1:
|
if n == 1:
|
||||||
out = model + prompt + gen(name="answer",
|
out = (
|
||||||
max_tokens=max_tokens, temperature=temperature, stop=stop)
|
model
|
||||||
|
+ prompt
|
||||||
|
+ gen(
|
||||||
|
name="answer",
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
|
)
|
||||||
return out["answer"]
|
return out["answer"]
|
||||||
else:
|
else:
|
||||||
rets = []
|
rets = []
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
out = model + prompt + gen(name="answer",
|
out = (
|
||||||
max_tokens=max_tokens, temperature=temperature, stop=stop)
|
model
|
||||||
|
+ prompt
|
||||||
|
+ gen(
|
||||||
|
name="answer",
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
|
)
|
||||||
rets.append(out["answer"])
|
rets.append(out["answer"])
|
||||||
return rets
|
return rets
|
||||||
|
|
||||||
@@ -146,6 +191,7 @@ def main(args):
|
|||||||
|
|
||||||
# Run requests
|
# Run requests
|
||||||
states = [None] * len(questions)
|
states = [None] * len(questions)
|
||||||
|
|
||||||
def get_one_answer(i):
|
def get_one_answer(i):
|
||||||
states[i] = tree_search(**arguments[i], call_generate=call_generate)
|
states[i] = tree_search(**arguments[i], call_generate=call_generate)
|
||||||
|
|
||||||
@@ -188,7 +234,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -1,22 +1,25 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
from collections import Counter
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
|
||||||
import sglang as sgl
|
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
INVALID = -9999999
|
INVALID = -9999999
|
||||||
|
|
||||||
|
|
||||||
def get_answer_value(answer_str):
|
def get_answer_value(answer_str):
|
||||||
answer_str = answer_str.replace(",", "")
|
answer_str = answer_str.replace(",", "")
|
||||||
numbers = re.findall(r'\d+', answer_str)
|
numbers = re.findall(r"\d+", answer_str)
|
||||||
if len(numbers) < 1:
|
if len(numbers) < 1:
|
||||||
return INVALID
|
return INVALID
|
||||||
try:
|
try:
|
||||||
@@ -40,7 +43,9 @@ temp = 0.001
|
|||||||
|
|
||||||
def propose_plan(s, question, num_branches):
|
def propose_plan(s, question, num_branches):
|
||||||
s += sgl.user(
|
s += sgl.user(
|
||||||
"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question)
|
"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
|
||||||
|
+ question
|
||||||
|
)
|
||||||
forks = s.fork(num_branches)
|
forks = s.fork(num_branches)
|
||||||
forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
|
forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
|
||||||
return forks
|
return forks
|
||||||
@@ -48,7 +53,8 @@ def propose_plan(s, question, num_branches):
|
|||||||
|
|
||||||
def execute_plan(s, num_branches):
|
def execute_plan(s, num_branches):
|
||||||
s += sgl.user(
|
s += sgl.user(
|
||||||
"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""")
|
"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
|
||||||
|
)
|
||||||
forks = s.fork(num_branches)
|
forks = s.fork(num_branches)
|
||||||
forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
|
forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
|
||||||
return forks
|
return forks
|
||||||
@@ -56,7 +62,8 @@ def execute_plan(s, num_branches):
|
|||||||
|
|
||||||
def reflect_solution(s, num_branches):
|
def reflect_solution(s, num_branches):
|
||||||
s += sgl.user(
|
s += sgl.user(
|
||||||
"""Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""")
|
"""Okay. Now, evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
|
||||||
|
)
|
||||||
forks = s.fork(num_branches)
|
forks = s.fork(num_branches)
|
||||||
forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
|
forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
|
||||||
return forks
|
return forks
|
||||||
@@ -64,13 +71,13 @@ def reflect_solution(s, num_branches):
|
|||||||
|
|
||||||
def get_final_answer(s, num_branches):
|
def get_final_answer(s, num_branches):
|
||||||
s += sgl.user(
|
s += sgl.user(
|
||||||
"""Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration.""")
|
"""Based on your reflection, do you change your mind? Now, give me the final answer after careful consideration."""
|
||||||
|
)
|
||||||
forks = s.fork(num_branches)
|
forks = s.fork(num_branches)
|
||||||
forks += sgl.assistant(sgl.gen("final_answer", max_tokens=256, temperature=temp))
|
forks += sgl.assistant(sgl.gen("final_answer", max_tokens=256, temperature=temp))
|
||||||
return forks
|
return forks
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@sgl.function
|
@sgl.function
|
||||||
def tree_search(s, question, num_branches):
|
def tree_search(s, question, num_branches):
|
||||||
plan_forks = propose_plan(s, question, num_branches)
|
plan_forks = propose_plan(s, question, num_branches)
|
||||||
@@ -93,6 +100,7 @@ def tree_search(s, question, num_branches):
|
|||||||
|
|
||||||
return solutions
|
return solutions
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
lines = read_jsonl(args.data_path)
|
lines = read_jsonl(args.data_path)
|
||||||
|
|
||||||
@@ -100,7 +108,7 @@ def main(args):
|
|||||||
num_branches = 2
|
num_branches = 2
|
||||||
questions = []
|
questions = []
|
||||||
labels = []
|
labels = []
|
||||||
for i in range(len(lines[:args.num_questions])):
|
for i in range(len(lines[: args.num_questions])):
|
||||||
questions.append(lines[i]["question"])
|
questions.append(lines[i]["question"])
|
||||||
labels.append(get_answer_value(lines[i]["answer"]))
|
labels.append(get_answer_value(lines[i]["answer"]))
|
||||||
assert all(l != INVALID for l in labels)
|
assert all(l != INVALID for l in labels)
|
||||||
@@ -112,7 +120,12 @@ def main(args):
|
|||||||
# Run requests
|
# Run requests
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
states = tree_search.run_batch(
|
states = tree_search.run_batch(
|
||||||
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True)
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
backend=backend,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
answers_text = []
|
answers_text = []
|
||||||
for s in states:
|
for s in states:
|
||||||
@@ -144,7 +157,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -1,25 +1,29 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
import asyncio
|
|
||||||
from collections import Counter
|
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from functools import partial
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from collections import Counter
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt_raw
|
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
|
||||||
|
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_other_args_and_parse,
|
||||||
|
call_generate_lightllm,
|
||||||
|
call_generate_srt_raw,
|
||||||
|
call_generate_vllm,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
INVALID = -9999999
|
INVALID = -9999999
|
||||||
|
|
||||||
|
|
||||||
def get_answer_value(answer_str):
|
def get_answer_value(answer_str):
|
||||||
answer_str = answer_str.replace(",", "")
|
answer_str = answer_str.replace(",", "")
|
||||||
numbers = re.findall(r'\d+', answer_str)
|
numbers = re.findall(r"\d+", answer_str)
|
||||||
if len(numbers) < 1:
|
if len(numbers) < 1:
|
||||||
return INVALID
|
return INVALID
|
||||||
try:
|
try:
|
||||||
@@ -47,27 +51,43 @@ temp = 0.3
|
|||||||
|
|
||||||
|
|
||||||
def propose_plan(s, question, num_branches, call_generate):
|
def propose_plan(s, question, num_branches, call_generate):
|
||||||
s += (USER_PREFIX +
|
s += (
|
||||||
"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question + USER_SUFFIX)
|
USER_PREFIX
|
||||||
|
+ """Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
|
||||||
|
+ question
|
||||||
|
+ USER_SUFFIX
|
||||||
|
)
|
||||||
|
|
||||||
s += ASSISTANT_PREFIX
|
s += ASSISTANT_PREFIX
|
||||||
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
|
comps = call_generate(
|
||||||
|
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
|
||||||
|
)
|
||||||
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
||||||
|
|
||||||
|
|
||||||
def execute_plan(s, num_branches, call_generate):
|
def execute_plan(s, num_branches, call_generate):
|
||||||
s += (USER_PREFIX +
|
s += (
|
||||||
"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""" + USER_SUFFIX)
|
USER_PREFIX
|
||||||
|
+ """The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
|
||||||
|
+ USER_SUFFIX
|
||||||
|
)
|
||||||
s += ASSISTANT_PREFIX
|
s += ASSISTANT_PREFIX
|
||||||
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
|
comps = call_generate(
|
||||||
|
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
|
||||||
|
)
|
||||||
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
||||||
|
|
||||||
|
|
||||||
def reflect_solution(s, num_branches, call_generate):
|
def reflect_solution(s, num_branches, call_generate):
|
||||||
s += (USER_PREFIX +
|
s += (
|
||||||
"""Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""" + USER_SUFFIX)
|
USER_PREFIX
|
||||||
|
+ """Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
|
||||||
|
+ USER_SUFFIX
|
||||||
|
)
|
||||||
s += ASSISTANT_PREFIX
|
s += ASSISTANT_PREFIX
|
||||||
comps = call_generate(s, max_tokens=256, temperature=temp, stop=None, n=num_branches)
|
comps = call_generate(
|
||||||
|
s, max_tokens=256, temperature=temp, stop=None, n=num_branches
|
||||||
|
)
|
||||||
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
return [s + comp + ASSISTANT_SUFFIX for comp in comps]
|
||||||
|
|
||||||
|
|
||||||
@@ -92,7 +112,7 @@ def main(args):
|
|||||||
num_branches = 3
|
num_branches = 3
|
||||||
questions = []
|
questions = []
|
||||||
labels = []
|
labels = []
|
||||||
for i in range(len(lines[:args.num_questions])):
|
for i in range(len(lines[: args.num_questions])):
|
||||||
questions.append(lines[i]["question"])
|
questions.append(lines[i]["question"])
|
||||||
labels.append(get_answer_value(lines[i]["answer"]))
|
labels.append(get_answer_value(lines[i]["answer"]))
|
||||||
assert all(l != INVALID for l in labels)
|
assert all(l != INVALID for l in labels)
|
||||||
@@ -109,25 +129,46 @@ def main(args):
|
|||||||
url = f"{args.host}:{args.port}/generate"
|
url = f"{args.host}:{args.port}/generate"
|
||||||
call_generate = partial(call_generate_srt_raw, url=url)
|
call_generate = partial(call_generate_srt_raw, url=url)
|
||||||
elif args.backend == "guidance":
|
elif args.backend == "guidance":
|
||||||
from guidance import models, gen
|
from guidance import gen, models
|
||||||
|
|
||||||
model = models.LlamaCpp("/home/ubuntu/model_weights/Llama-2-7b-chat.gguf", n_gpu_layers=-1, n_ctx=4096)
|
model = models.LlamaCpp(
|
||||||
|
"/home/ubuntu/model_weights/Llama-2-7b-chat.gguf",
|
||||||
|
n_gpu_layers=-1,
|
||||||
|
n_ctx=4096,
|
||||||
|
)
|
||||||
|
|
||||||
def call_generate(prompt, temperature, max_tokens, stop, n):
|
def call_generate(prompt, temperature, max_tokens, stop, n):
|
||||||
if n == 1:
|
if n == 1:
|
||||||
out = model + prompt + gen(name="answer",
|
out = (
|
||||||
max_tokens=max_tokens, temperature=temperature, stop=stop)
|
model
|
||||||
|
+ prompt
|
||||||
|
+ gen(
|
||||||
|
name="answer",
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
|
)
|
||||||
return out["answer"]
|
return out["answer"]
|
||||||
else:
|
else:
|
||||||
rets = []
|
rets = []
|
||||||
for i in range(n):
|
for i in range(n):
|
||||||
out = model + prompt + gen(name="answer",
|
out = (
|
||||||
max_tokens=max_tokens, temperature=temperature, stop=stop)
|
model
|
||||||
|
+ prompt
|
||||||
|
+ gen(
|
||||||
|
name="answer",
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
temperature=temperature,
|
||||||
|
stop=stop,
|
||||||
|
)
|
||||||
|
)
|
||||||
rets.append(out["answer"])
|
rets.append(out["answer"])
|
||||||
return rets
|
return rets
|
||||||
|
|
||||||
# Run requests
|
# Run requests
|
||||||
states = [None] * len(questions)
|
states = [None] * len(questions)
|
||||||
|
|
||||||
def get_one_answer(i):
|
def get_one_answer(i):
|
||||||
states[i] = tree_search(**arguments[i], call_generate=call_generate)
|
states[i] = tree_search(**arguments[i], call_generate=call_generate)
|
||||||
|
|
||||||
@@ -170,7 +211,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -1,22 +1,25 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import ast
|
import ast
|
||||||
from collections import Counter
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
|
||||||
from sglang.utils import read_jsonl, dump_state_text
|
|
||||||
import sglang as sgl
|
|
||||||
|
|
||||||
|
import sglang as sgl
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
add_common_sglang_args_and_parse,
|
||||||
|
select_sglang_backend,
|
||||||
|
)
|
||||||
|
from sglang.utils import dump_state_text, read_jsonl
|
||||||
|
|
||||||
INVALID = -9999999
|
INVALID = -9999999
|
||||||
|
|
||||||
|
|
||||||
def get_answer_value(answer_str):
|
def get_answer_value(answer_str):
|
||||||
answer_str = answer_str.replace(",", "")
|
answer_str = answer_str.replace(",", "")
|
||||||
numbers = re.findall(r'\d+', answer_str)
|
numbers = re.findall(r"\d+", answer_str)
|
||||||
if len(numbers) < 1:
|
if len(numbers) < 1:
|
||||||
return INVALID
|
return INVALID
|
||||||
try:
|
try:
|
||||||
@@ -40,7 +43,9 @@ temp = 0.3
|
|||||||
|
|
||||||
def propose_plan(s, question, num_branches):
|
def propose_plan(s, question, num_branches):
|
||||||
s += sgl.user(
|
s += sgl.user(
|
||||||
"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """ + question)
|
"""Please generate a high-level plan for solving the following question. As the first step, just say what method and idea you will use to solve the question. You can reorganize the information in the question. Do not do the actual calculation. Keep your response concise and within 80 words. Question: """
|
||||||
|
+ question
|
||||||
|
)
|
||||||
forks = s.fork(num_branches)
|
forks = s.fork(num_branches)
|
||||||
forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
|
forks += sgl.assistant(sgl.gen("plan", max_tokens=256, temperature=temp))
|
||||||
return forks
|
return forks
|
||||||
@@ -48,7 +53,8 @@ def propose_plan(s, question, num_branches):
|
|||||||
|
|
||||||
def execute_plan(s, num_branches):
|
def execute_plan(s, num_branches):
|
||||||
s += sgl.user(
|
s += sgl.user(
|
||||||
"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short.""")
|
"""The plan looks good! Now, use real numbers and do the calculation. Please solve the question step-by-step according to the high-level plan. Give me the final answer. Make your response short."""
|
||||||
|
)
|
||||||
forks = s.fork(num_branches)
|
forks = s.fork(num_branches)
|
||||||
forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
|
forks += sgl.assistant(sgl.gen("answer", max_tokens=256, temperature=temp))
|
||||||
return forks
|
return forks
|
||||||
@@ -56,7 +62,8 @@ def execute_plan(s, num_branches):
|
|||||||
|
|
||||||
def reflect_solution(s, num_branches):
|
def reflect_solution(s, num_branches):
|
||||||
s += sgl.user(
|
s += sgl.user(
|
||||||
"""Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness.""")
|
"""Okay. Now you evaluate your own solution and give it a score on a scale of 1 to 5. Please do rigorous check of the correctness."""
|
||||||
|
)
|
||||||
forks = s.fork(num_branches)
|
forks = s.fork(num_branches)
|
||||||
forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
|
forks += sgl.assistant(sgl.gen("score", max_tokens=256, temperature=temp))
|
||||||
return forks
|
return forks
|
||||||
@@ -90,7 +97,7 @@ def main(args):
|
|||||||
num_branches = 3
|
num_branches = 3
|
||||||
questions = []
|
questions = []
|
||||||
labels = []
|
labels = []
|
||||||
for i in range(len(lines[:args.num_questions])):
|
for i in range(len(lines[: args.num_questions])):
|
||||||
questions.append(lines[i]["question"])
|
questions.append(lines[i]["question"])
|
||||||
labels.append(get_answer_value(lines[i]["answer"]))
|
labels.append(get_answer_value(lines[i]["answer"]))
|
||||||
assert all(l != INVALID for l in labels)
|
assert all(l != INVALID for l in labels)
|
||||||
@@ -102,7 +109,12 @@ def main(args):
|
|||||||
# Run requests
|
# Run requests
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
states = tree_search.run_batch(
|
states = tree_search.run_batch(
|
||||||
arguments, temperature=0, backend=backend, num_threads=args.parallel, progress_bar=True)
|
arguments,
|
||||||
|
temperature=0,
|
||||||
|
backend=backend,
|
||||||
|
num_threads=args.parallel,
|
||||||
|
progress_bar=True,
|
||||||
|
)
|
||||||
latency = time.time() - tic
|
latency = time.time() - tic
|
||||||
answers_text = []
|
answers_text = []
|
||||||
for s in states:
|
for s in states:
|
||||||
@@ -134,7 +146,7 @@ def main(args):
|
|||||||
"other": {
|
"other": {
|
||||||
"num_questions": args.num_questions,
|
"num_questions": args.num_questions,
|
||||||
"parallel": args.parallel,
|
"parallel": args.parallel,
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
fout.write(json.dumps(value) + "\n")
|
fout.write(json.dumps(value) + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -3,3 +3,6 @@ black python
|
|||||||
|
|
||||||
isort test
|
isort test
|
||||||
black test
|
black test
|
||||||
|
|
||||||
|
isort benchmark
|
||||||
|
black benchmark
|
||||||
|
|||||||
Reference in New Issue
Block a user