misc: recommend using a chat model for benchmarks (#690)
This commit is contained in:
@@ -630,6 +630,7 @@ async def benchmark(
|
||||
"random_input_len": args.random_input_len,
|
||||
"random_output_len": args.random_output_len,
|
||||
"random_range_ratio": args.random_range_ratio,
|
||||
"benchmark_duration": benchmark_duration,
|
||||
}
|
||||
else:
|
||||
print(f"Error running benchmark for request rate: {request_rate}")
|
||||
@@ -687,6 +688,15 @@ def parse_request_rate_range(request_rate_range):
|
||||
return list(map(int, request_rate_range.split(",")))
|
||||
|
||||
|
||||
def check_chat_template(model_path):
    """Best-effort check: does the tokenizer at *model_path* ship a chat template?

    Loads the tokenizer config via transformers' AutoTokenizer and looks for a
    "chat_template" entry in its init kwargs. Any failure to load the tokenizer
    (bad path, missing files, remote-code issues) is reported and treated as
    "no template" so callers can fall back gracefully.

    Args:
        model_path: Model identifier or local path accepted by
            ``AutoTokenizer.from_pretrained``.

    Returns:
        True if the tokenizer declares a chat template, False otherwise
        (including when the tokenizer cannot be loaded).
    """
    has_template = False
    try:
        tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        has_template = "chat_template" in tok.init_kwargs
    except Exception as e:
        # Loading can fail for many reasons (network, bad path, remote code);
        # report and fall through to the conservative default.
        print(f"Fail to load tokenizer config with error={e}")
    return has_template
|
||||
|
||||
|
||||
def fire(args: argparse.Namespace):
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
@@ -736,6 +746,12 @@ def fire(args: argparse.Namespace):
|
||||
print("No model specified or found. Please provide a model using `--model`.")
|
||||
sys.exit(1)
|
||||
|
||||
if not check_chat_template(args.model):
|
||||
print(
|
||||
"\nWARNING It is recommended to use the `Chat` or `Instruct` model for benchmarking.\n"
|
||||
"Because when the tokenizer counts the output tokens, if there is gibberish, it might count incorrectly.\n"
|
||||
)
|
||||
|
||||
print(f"{args}\n")
|
||||
|
||||
backend = args.backend
|
||||
|
||||
Reference in New Issue
Block a user