Format Benchmark Code (#399)
This commit is contained in:
@@ -1,14 +1,19 @@
|
||||
import argparse
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
|
||||
from fastchat.model import get_conversation_template
|
||||
import requests
|
||||
from sglang.test.test_utils import add_common_other_args_and_parse, call_generate_lightllm, call_generate_vllm, call_generate_srt
|
||||
|
||||
from sglang.test.test_utils import (
|
||||
add_common_other_args_and_parse,
|
||||
call_generate_lightllm,
|
||||
call_generate_srt,
|
||||
call_generate_vllm,
|
||||
)
|
||||
|
||||
|
||||
def load_questions(filename):
|
||||
@@ -38,7 +43,7 @@ def write_answers(filename, model_id, questions, answers):
|
||||
|
||||
def main(args):
|
||||
questions = load_questions(args.question_file)
|
||||
questions = (questions * 10)[:args.num_questions]
|
||||
questions = (questions * 10)[: args.num_questions]
|
||||
max_tokens = 256
|
||||
model_id = "llama-2-chat"
|
||||
|
||||
@@ -67,9 +72,8 @@ def main(args):
|
||||
conv.append_message(conv.roles[0], q)
|
||||
conv.append_message(conv.roles[1], None)
|
||||
|
||||
prompt = conv.get_prompt()
|
||||
output = call_generate(prompt,
|
||||
temperature=0, max_tokens=max_tokens).strip()
|
||||
prompt = conv.get_prompt()
|
||||
output = call_generate(prompt, temperature=0, max_tokens=max_tokens).strip()
|
||||
|
||||
cur_answers.append(output)
|
||||
conv.update_last_message(output)
|
||||
@@ -102,7 +106,7 @@ def main(args):
|
||||
"other": {
|
||||
"num_questions": args.num_questions,
|
||||
"parallel": args.parallel,
|
||||
}
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
@@ -5,7 +5,10 @@ import time
|
||||
import uuid
|
||||
|
||||
import sglang as sgl
|
||||
from sglang.test.test_utils import add_common_sglang_args_and_parse, select_sglang_backend
|
||||
from sglang.test.test_utils import (
|
||||
add_common_sglang_args_and_parse,
|
||||
select_sglang_backend,
|
||||
)
|
||||
|
||||
|
||||
def load_questions(filename):
|
||||
@@ -44,10 +47,9 @@ def answer_mt_bench(s, question_1, question_2):
|
||||
|
||||
def main(args):
|
||||
# Construct prompts
|
||||
questions = load_questions(args.question_file)[:args.num_questions]
|
||||
questions = load_questions(args.question_file)[: args.num_questions]
|
||||
arguments = [
|
||||
{"question_1": q["turns"][0], "question_2": q["turns"][1]}
|
||||
for q in questions
|
||||
{"question_1": q["turns"][0], "question_2": q["turns"][1]} for q in questions
|
||||
]
|
||||
|
||||
# Select backend
|
||||
@@ -83,7 +85,7 @@ def main(args):
|
||||
"other": {
|
||||
"num_questions": args.num_questions,
|
||||
"parallel": args.parallel,
|
||||
}
|
||||
},
|
||||
}
|
||||
fout.write(json.dumps(value) + "\n")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user